diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 7376d1ddc4818..e04524dde0a75 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -30,10 +30,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; -import io.netty.util.AbstractReferenceCounted; import org.apache.commons.crypto.stream.CryptoInputStream; import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteArrayReadableChannel; import org.apache.spark.network.util.ByteArrayWritableChannel; @@ -161,7 +161,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { } } - private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; @@ -199,10 +199,45 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return transferred; } + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (region != null) { + region.touch(o); + } + if (buf != null) { + buf.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (region != null) { + region.retain(increment); + } + if (buf != null) { + buf.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + @Override public long transferTo(WritableByteChannel target, long position) throws IOException { Preconditions.checkArgument(position == transfered(), "Invalid position."); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 4f8781b42a0e4..897d0f9e4fb89 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -25,17 +25,17 @@ import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.AbstractFileRegion; /** * A wrapper message that holds two separate pieces (a header and a body). * * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. 
*/ -class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { +class MessageWithHeader extends AbstractFileRegion { @Nullable private final ManagedBuffer managedBuffer; private final ByteBuf header; @@ -91,7 +91,7 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return totalBytesTransferred; } @@ -160,4 +160,37 @@ private int writeNioBuffer( return ret; } + + @Override + public MessageWithHeader touch(Object o) { + super.touch(o); + header.touch(o); + ReferenceCountUtil.touch(body, o); + return this; + } + + @Override + public MessageWithHeader retain(int increment) { + super.retain(increment); + header.retain(increment); + ReferenceCountUtil.retain(body, increment); + if (managedBuffer != null) { + for (int i = 0; i < increment; i++) { + managedBuffer.retain(); + } + } + return this; + } + + @Override + public boolean release(int decrement) { + header.release(decrement); + ReferenceCountUtil.release(body, decrement); + if (managedBuffer != null) { + for (int i = 0; i < decrement; i++) { + managedBuffer.release(); + } + } + return super.release(decrement); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3d71ebaa7ea0c..16ab4efcd4f5f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -32,8 +32,8 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.FileRegion; import io.netty.handler.codec.MessageToMessageDecoder; -import io.netty.util.AbstractReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; @@ -129,7 +129,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) } @VisibleForTesting - static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + static class EncryptedMessage extends AbstractFileRegion { private final SaslEncryptionBackend backend; private final boolean isByteBuf; @@ -183,10 +183,45 @@ public long position() { * Returns an approximation of the amount of data transferred. See {@link #count()}. */ @Override - public long transfered() { + public long transferred() { return transferred; } + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (buf != null) { + buf.touch(o); + } + if (region != null) { + region.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (buf != null) { + buf.retain(increment); + } + if (region != null) { + region.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + /** * Transfers data from the original message to the channel, encrypting it in the process. 
* diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java new file mode 100644 index 0000000000000..8651297d97ec2 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +public abstract class AbstractFileRegion extends AbstractReferenceCounted implements FileRegion { + + @Override + @SuppressWarnings("deprecation") + public final long transfered() { + return transferred(); + } + + @Override + public AbstractFileRegion retain() { + super.retain(); + return this; + } + + @Override + public AbstractFileRegion retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public AbstractFileRegion touch() { + super.touch(); + return this; + } + + @Override + public AbstractFileRegion touch(Object o) { + return this; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bb1c40c4b0e06..bc94f7ca63a96 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -56,7 +56,7 @@ private void testServerToClient(Message msg) { NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { - clientChannel.writeInbound(serverChannel.readOutbound()); + clientChannel.writeOneInbound(serverChannel.readOutbound()); } assertEquals(1, clientChannel.inboundMessages().size()); @@ -72,7 +72,7 @@ private void testClientToServer(Message msg) { NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { - serverChannel.writeInbound(clientChannel.readOutbound()); + serverChannel.writeOneInbound(clientChannel.readOutbound()); } assertEquals(1, serverChannel.inboundMessages().size()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index b341c5681e00c..ecb66fcf2ff76 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -23,8 +23,7 @@ import io.netty.buffer.ByteBuf; import 
io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; import org.mockito.Mockito; @@ -108,7 +107,7 @@ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exc return Unpooled.wrappedBuffer(channel.getData()); } - private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + private static class TestFileRegion extends AbstractFileRegion { private final int writeCount; private final int writesPerCall; @@ -130,7 +129,7 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return 8 * written; } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 97abd92d4b70f..39249d411b582 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -26,12 +26,11 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ListBuffer import com.google.common.io.Closeables -import io.netty.channel.{DefaultFileRegion, FileRegion} -import io.netty.util.AbstractReferenceCounted +import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils +import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer @@ -266,7 +265,7 @@ private class EncryptedBlockData( } private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) - extends AbstractReferenceCounted with FileRegion { + extends AbstractFileRegion { private var _transferred = 0L @@ -277,7 +276,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def position(): Long = 0 - override def transfered(): Long = _transferred + override def transferred(): Long = _transferred override def transferTo(target: WritableByteChannel, pos: Long): Long = { assert(pos == transfered(), "Invalid position.") diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2c68b73095c4d..82b9c4a603355 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -82,7 +82,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar +hppc-0.7.2.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -144,7 +144,7 @@ metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar -netty-all-4.0.47.Final.jar +netty-all-4.1.17.Final.jar objenesis-2.1.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2aaac600b3ec3..0795ea41b2d7f 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar 
api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -82,7 +82,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar +hppc-0.7.2.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -145,7 +145,7 @@ metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar -netty-all-4.0.47.Final.jar +netty-all-4.1.17.Final.jar objenesis-2.1.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar diff --git a/pom.xml b/pom.xml index 07bca9d267da0..d7a95ffd64452 100644 --- a/pom.xml +++ b/pom.xml @@ -185,7 +185,7 @@ <paranamer.version>2.8</paranamer.version> <maven-antrun.version>1.8</maven-antrun.version> <commons-crypto.version>1.0.0</commons-crypto.version> - <arrow.version>0.4.0</arrow.version> + <arrow.version>0.8.0</arrow.version> <test.java.home>${java.home}</test.java.home> @@ -580,7 +580,7 @@ <dependency> <groupId>io.netty</groupId> <artifactId>netty-all</artifactId> - <version>4.0.47.Final</version> + <version>4.1.17.Final</version> </dependency> <dependency> <groupId>io.netty</groupId> @@ -1972,6 +1972,14 @@ <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> </exclusion> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty-buffer</artifactId> + </exclusion> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty-common</artifactId> + </exclusion> <exclusion> <groupId>io.netty</groupId> <artifactId>netty-handler</artifactId> diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 37e7cf3fa662e..88d6a191babca 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -223,27 +223,14 @@ def _create_batch(series, timezone): series = [series] series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if type(t) == pa.TimestampType: - # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 - return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\ - .values.astype('datetime64[us]', copy=False) - # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 - elif t is not None and t == pa.date32(): - # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 - return s.dt.date - elif t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - # Some object types don't support masks in Arrow, see ARROW-1721 def create_array(s, t): - casted = cast_series(s, t) - mask = None if casted.dtype == 'object' else s.isnull() - return pa.Array.from_pandas(casted, mask=mask, type=t) + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) + # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 + return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9864dc98c1f33..7e3710912b3f9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1892,7 +1892,9 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.utils import _require_minimum_pyarrow_version import pyarrow +
_require_minimum_pyarrow_version() tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e0faddb1c0df..f79bdc6230545 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2141,16 +2141,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(StringType()) + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP + >>> @pandas_udf(StringType()) # doctest: +SKIP ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf("integer", PandasUDFType.SCALAR) + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP ... def add_one(x): ... return x + 1 ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df = spark.createDataFrame([(1, "John Doe", 21)], + ... ("id", "name", "age")) # doctest: +SKIP >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ ... .show() # doctest: +SKIP +----------+--------------+------------+ @@ -2171,8 +2172,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 4d47dd6a3e878..09fae46adf014 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -218,7 +218,7 @@ def apply(self, udf): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... 
return pdf.assign(v=(v - v.mean()) / v.std()) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e2435e09af23d..86db16eca7889 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -495,11 +495,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): from pyspark.serializers import ArrowSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ _old_pandas_exception_message, TimestampType + from pyspark.sql.utils import _require_minimum_pyarrow_version try: from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype except ImportError as e: raise ImportError(_old_pandas_exception_message(e)) + _require_minimum_pyarrow_version() + # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b4d32d8de8a22..6fdfda1cc831b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3339,10 +3339,11 @@ def test_createDataFrame_with_single_data_type(self): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): + import pandas as pd # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = 1 + pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3356,6 +3357,7 @@ def test_schema_conversion_roundtrip(self): self.assertEquals(self.schema, schema_rt) +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3671,9 +3673,9 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, StringType()) + f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3974,12 +3976,12 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, - 'id long, v string', + 'id long, v array', PandasUDFType.GROUP_MAP ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 78abc32a35a1c..46d9a417414b5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1642,29 +1642,28 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. 
""" - # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type - import pyarrow as pa - if at == pa.bool_(): + import pyarrow.types as types + if types.is_boolean(at): spark_type = BooleanType() - elif at == pa.int8(): + elif types.is_int8(at): spark_type = ByteType() - elif at == pa.int16(): + elif types.is_int16(at): spark_type = ShortType() - elif at == pa.int32(): + elif types.is_int32(at): spark_type = IntegerType() - elif at == pa.int64(): + elif types.is_int64(at): spark_type = LongType() - elif at == pa.float32(): + elif types.is_float32(at): spark_type = FloatType() - elif at == pa.float64(): + elif types.is_float64(at): spark_type = DoubleType() - elif type(at) == pa.DecimalType: + elif types.is_decimal(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) - elif at == pa.string(): + elif types.is_string(at): spark_type = StringType() - elif at == pa.date32(): + elif types.is_date32(at): spark_type = DateType() - elif type(at) == pa.TimestampType: + elif types.is_timestamp(at): spark_type = TimestampType() else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c3301a41ccd5a..50c87ba1ac882 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -33,19 +33,23 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ + evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: import inspect + from pyspark.sql.utils import _require_minimum_pyarrow_version + + _require_minimum_pyarrow_version() argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: + + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: raise ValueError( "Invalid function: pandas_udfs with function type GROUP_MAP " "must take a single arg that is a pandas DataFrame." 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 7bc6a59ad3b26..cc7dabb64b3ec 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -110,3 +110,12 @@ def toJArray(gateway, jtype, arr): for i in range(0, len(arr)): jarr[i] = arr[i] return jarr + + +def _require_minimum_pyarrow_version(): + """ Raise ImportError if minimum version of pyarrow is not installed + """ + from distutils.version import LooseVersion + import pyarrow + if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): + raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 0071bd66760be..73569589e1599 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -285,30 +285,30 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); - if (vector instanceof NullableBitVector) { - accessor = new BooleanAccessor((NullableBitVector) vector); - } else if (vector instanceof NullableTinyIntVector) { - accessor = new ByteAccessor((NullableTinyIntVector) vector); - } else if (vector instanceof NullableSmallIntVector) { - accessor = new ShortAccessor((NullableSmallIntVector) vector); - } else if (vector instanceof NullableIntVector) { - accessor = new IntAccessor((NullableIntVector) vector); - } else if (vector instanceof NullableBigIntVector) { - accessor = new LongAccessor((NullableBigIntVector) vector); - } else if (vector instanceof NullableFloat4Vector) { - accessor = new FloatAccessor((NullableFloat4Vector) vector); - } else if (vector instanceof NullableFloat8Vector) { - accessor = new DoubleAccessor((NullableFloat8Vector) vector); - } else if (vector instanceof NullableDecimalVector) { - accessor = new DecimalAccessor((NullableDecimalVector) vector); - } else if (vector instanceof NullableVarCharVector) { - accessor = new StringAccessor((NullableVarCharVector) vector); - } else if (vector instanceof NullableVarBinaryVector) { - accessor = new BinaryAccessor((NullableVarBinaryVector) vector); - } else if (vector instanceof NullableDateDayVector) { - accessor = new DateAccessor((NullableDateDayVector) vector); - } else if (vector instanceof NullableTimeStampMicroTZVector) { - accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); + if (vector instanceof BitVector) { + accessor = new BooleanAccessor((BitVector) vector); + } else if (vector instanceof TinyIntVector) { + accessor = new ByteAccessor((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + accessor = new ShortAccessor((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + accessor = new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + accessor = new LongAccessor((BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + accessor = new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + accessor = new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof DecimalVector) { + accessor = new DecimalAccessor((DecimalVector) vector); + } else if (vector instanceof VarCharVector) { + accessor = new StringAccessor((VarCharVector) 
vector); + } else if (vector instanceof VarBinaryVector) { + accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -332,23 +332,21 @@ public ArrowColumnVector(ValueVector vector) { private abstract static class ArrowVectorAccessor { private final ValueVector vector; - private final ValueVector.Accessor nulls; ArrowVectorAccessor(ValueVector vector) { this.vector = vector; - this.nulls = vector.getAccessor(); } final boolean isNullAt(int rowId) { - return nulls.isNull(rowId); + return vector.isNull(rowId); } final int getValueCount() { - return nulls.getValueCount(); + return vector.getValueCount(); } final int getNullCount() { - return nulls.getNullCount(); + return vector.getNullCount(); } final void close() { @@ -406,11 +404,11 @@ int getArrayOffset(int rowId) { private static class BooleanAccessor extends ArrowVectorAccessor { - private final NullableBitVector.Accessor accessor; + private final BitVector accessor; - BooleanAccessor(NullableBitVector vector) { + BooleanAccessor(BitVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -421,11 +419,11 @@ final boolean getBoolean(int rowId) { private static class ByteAccessor extends ArrowVectorAccessor { - private final NullableTinyIntVector.Accessor accessor; + private final TinyIntVector accessor; - ByteAccessor(NullableTinyIntVector vector) { + ByteAccessor(TinyIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -436,11 +434,11 @@ final byte getByte(int rowId) { private static class ShortAccessor extends ArrowVectorAccessor { - private final NullableSmallIntVector.Accessor accessor; + private final SmallIntVector accessor; - ShortAccessor(NullableSmallIntVector vector) { + ShortAccessor(SmallIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -451,11 +449,11 @@ final short getShort(int rowId) { private static class IntAccessor extends ArrowVectorAccessor { - private final NullableIntVector.Accessor accessor; + private final IntVector accessor; - IntAccessor(NullableIntVector vector) { + IntAccessor(IntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -466,11 +464,11 @@ final int getInt(int rowId) { private static class LongAccessor extends ArrowVectorAccessor { - private final NullableBigIntVector.Accessor accessor; + private final BigIntVector accessor; - LongAccessor(NullableBigIntVector vector) { + LongAccessor(BigIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -481,11 +479,11 @@ final long getLong(int rowId) { private static class FloatAccessor extends ArrowVectorAccessor { - private final NullableFloat4Vector.Accessor accessor; + private final Float4Vector accessor; - FloatAccessor(NullableFloat4Vector vector) { + FloatAccessor(Float4Vector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -496,11 +494,11 @@ final float getFloat(int rowId) { private static class DoubleAccessor extends ArrowVectorAccessor { 
- private final NullableFloat8Vector.Accessor accessor; + private final Float8Vector accessor; - DoubleAccessor(NullableFloat8Vector vector) { + DoubleAccessor(Float8Vector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -511,11 +509,11 @@ final double getDouble(int rowId) { private static class DecimalAccessor extends ArrowVectorAccessor { - private final NullableDecimalVector.Accessor accessor; + private final DecimalVector accessor; - DecimalAccessor(NullableDecimalVector vector) { + DecimalAccessor(DecimalVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -527,12 +525,12 @@ final Decimal getDecimal(int rowId, int precision, int scale) { private static class StringAccessor extends ArrowVectorAccessor { - private final NullableVarCharVector.Accessor accessor; + private final VarCharVector accessor; private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); - StringAccessor(NullableVarCharVector vector) { + StringAccessor(VarCharVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -550,11 +548,11 @@ final UTF8String getUTF8String(int rowId) { private static class BinaryAccessor extends ArrowVectorAccessor { - private final NullableVarBinaryVector.Accessor accessor; + private final VarBinaryVector accessor; - BinaryAccessor(NullableVarBinaryVector vector) { + BinaryAccessor(VarBinaryVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -565,11 +563,11 @@ final byte[] getBinary(int rowId) { private static class DateAccessor extends ArrowVectorAccessor { - private final NullableDateDayVector.Accessor accessor; + private final DateDayVector accessor; - DateAccessor(NullableDateDayVector vector) { + DateAccessor(DateDayVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -580,11 +578,11 @@ final int getInt(int rowId) { private static class TimestampAccessor extends ArrowVectorAccessor { - private final NullableTimeStampMicroTZVector.Accessor accessor; + private final TimeStampMicroTZVector accessor; - TimestampAccessor(NullableTimeStampMicroTZVector vector) { + TimestampAccessor(TimeStampMicroTZVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -595,21 +593,21 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { - private final UInt4Vector.Accessor accessor; + private final ListVector accessor; ArrayAccessor(ListVector vector) { super(vector); - this.accessor = vector.getOffsetVector().getAccessor(); + this.accessor = vector; } @Override final int getArrayLength(int rowId) { - return accessor.get(rowId + 1) - accessor.get(rowId); + return accessor.getInnerValueCountAt(rowId); } @Override final int getArrayOffset(int rowId) { - return accessor.get(rowId); + return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 3cafb344ef553..bcfc412430263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -24,8 +24,8 @@ import 
scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.ArrowRecordBatch +import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext @@ -86,13 +86,9 @@ private[sql] object ArrowConverters { val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - var closed = false - context.addTaskCompletionListener { _ => - if (!closed) { - root.close() - allocator.close() - } + root.close() + allocator.close() } new Iterator[ArrowPayload] { @@ -100,7 +96,6 @@ private[sql] object ArrowConverters { override def hasNext: Boolean = rowIter.hasNext || { root.close() allocator.close() - closed = true false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index e4af4f65da127..0258056d9de49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -46,17 +46,17 @@ object ArrowWriter { private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { - case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector) - case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector) - case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector) - case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector) - case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) - case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) - case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) - case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) - case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) - case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) - case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) + case (BooleanType, vector: BitVector) => new BooleanWriter(vector) + case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: IntVector) => new IntegerWriter(vector) + case (LongType, vector: BigIntVector) => new LongWriter(vector) + case (FloatType, vector: Float4Vector) => new FloatWriter(vector) + case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: DateDayVector) => new DateWriter(vector) + case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -103,7 +103,6 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { private[arrow] abstract class ArrowFieldWriter { def valueVector: ValueVector - def valueMutator: ValueVector.Mutator def 
name: String = valueVector.getField().getName() def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) @@ -124,161 +123,144 @@ private[arrow] abstract class ArrowFieldWriter { } def finish(): Unit = { - valueMutator.setValueCount(count) + valueVector.setValueCount(count) } def reset(): Unit = { - valueMutator.reset() + // TODO: reset() should be in a common interface + valueVector match { + case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() + case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case _ => + } count = 0 } } -private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { - - override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() +private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) } } -private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() +private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getByte(ordinal)) + valueVector.setSafe(count, input.getByte(ordinal)) } } -private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() +private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getShort(ordinal)) + valueVector.setSafe(count, input.getShort(ordinal)) } } -private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() +private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getInt(ordinal)) + valueVector.setSafe(count, input.getInt(ordinal)) } } -private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() +private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getLong(ordinal)) + valueVector.setSafe(count, input.getLong(ordinal)) } } -private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { - - override def valueMutator: 
NullableFloat4Vector#Mutator = valueVector.getMutator() +private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getFloat(ordinal)) + valueVector.setSafe(count, input.getFloat(ordinal)) } } -private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { - - override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() +private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getDouble(ordinal)) + valueVector.setSafe(count, input.getDouble(ordinal)) } } -private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { - - override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() +private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val utf8 = input.getUTF8String(ordinal) val utf8ByteBuffer = utf8.getByteBuffer // todo: for off-heap UTF8String, how to pass in to arrow without copy? - valueMutator.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) + valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) } } private[arrow] class BinaryWriter( - val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { - - override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() + val valueVector: VarBinaryVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val bytes = input.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) + valueVector.setSafe(count, bytes, 0, bytes.length) } } -private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { - - override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() +private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getInt(ordinal)) + valueVector.setSafe(count, input.getInt(ordinal)) } } private[arrow] class TimestampWriter( - val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { - - override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getLong(ordinal)) + valueVector.setSafe(count, input.getLong(ordinal)) } } @@ -286,20 +268,18 @@ private[arrow] class ArrayWriter( val valueVector: ListVector, 
val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { - override def valueMutator: ListVector#Mutator = valueVector.getMutator() - override def setNull(): Unit = { } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val array = input.getArray(ordinal) var i = 0 - valueMutator.startNewValue(count) + valueVector.startNewValue(count) while (i < array.numElements()) { elementWriter.write(array, i) i += 1 } - valueMutator.endValue(count, array.numElements()) + valueVector.endValue(count, array.numElements()) } override def finish(): Unit = { @@ -317,8 +297,6 @@ private[arrow] class StructWriter( val valueVector: NullableMapVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { - override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator() - override def setNull(): Unit = { var i = 0 while (i < children.length) { @@ -326,7 +304,7 @@ private[arrow] class StructWriter( children(i).count += 1 i += 1 } - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { @@ -336,7 +314,7 @@ private[arrow] class StructWriter( children(i).write(struct, i) i += 1 } - valueMutator.setIndexDefined(count) + valueVector.setIndexDefined(count) } override def finish(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 9a94d771a01b0..5cc8ed3535654 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.spark._ import org.apache.spark.api.python._ @@ -74,13 +74,9 @@ class ArrowPythonRunner( val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - var closed = false - context.addTaskCompletionListener { _ => - if (!closed) { - root.close() - allocator.close() - } + root.close() + allocator.close() } val writer = new ArrowStreamWriter(root, null, dataOut) @@ -102,7 +98,6 @@ class ArrowPythonRunner( writer.end() root.close() allocator.close() - closed = true } } } @@ -126,18 +121,11 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - private var closed = false - context.addTaskCompletionListener { _ => - // todo: we need something like `reader.end()`, which release all the resources, but leave - // the input stream open. `reader.close()` will close the socket and we can't reuse worker. - // So here we simply not close the reader, which is problematic. - if (!closed) { - if (root != null) { - root.close() - } - allocator.close() + if (reader != null) { + reader.close(false) } + allocator.close() } private var batchLoaded = true @@ -154,9 +142,8 @@ class ArrowPythonRunner( batch.setNumRows(root.getRowCount) batch } else { - root.close() + reader.close(false) allocator.close() - closed = true // Reach end of stream. Call `read()` again to read control data. 
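// Unlike a plain reader.close(), which would also close the socket and make the
// Python worker non-reusable, reader.close(false) releases the Arrow buffers but
// leaves the underlying input stream open. That is what makes the earlier
// "simply not close the reader" workaround (and its `closed` flag) unnecessary.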
read() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 57958f7239224..fd5a3df6abc68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -25,7 +25,7 @@ import java.util.Locale import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} -import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.ipc.JsonFileReader import org.apache.arrow.vector.util.Validator import org.scalatest.BeforeAndAfterAll @@ -76,16 +76,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_s", | "type" : { @@ -94,16 +85,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -143,16 +125,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_i", | "type" : { @@ -161,16 +134,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -210,16 +174,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_l", | "type" : { @@ -228,16 +183,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -276,16 +222,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_f", | "type" : { @@ -293,16 +230,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | 
"precision" : "SINGLE" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -341,16 +269,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_d", | "type" : { @@ -358,16 +277,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -408,16 +318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -449,16 +350,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -466,16 +358,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "c", | "type" : { @@ -484,16 +367,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "d", | "type" : { @@ -501,16 +375,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "e", | "type" : { @@ -519,16 +384,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -583,57 +439,21 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | 
"typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | }, { | "name" : "lower_case", | "type" : { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | }, { | "name" : "null_str", | "type" : { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -681,16 +501,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "bool" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -721,16 +532,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 8 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -760,19 +562,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "binary" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -807,16 +597,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "unit" : "DAY" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -855,16 +636,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "timezone" : "America/Los_Angeles" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -904,16 +676,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "NaN_d", | "type" : { @@ -921,16 +684,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | 
}, | "batches" : [ { @@ -939,12 +693,12 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "NaN_f", | "count" : 2, | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1.2000000476837158, "NaN" ] + | "DATA" : [ 1.2000000476837158, NaN ] | }, { | "name" : "NaN_d", | "count" : 2, | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ "NaN", 1.2 ] + | "DATA" : [ NaN, 1.2 ] | } ] | } ] |} @@ -976,26 +730,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "b_arr", | "nullable" : true, @@ -1010,26 +746,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "c_arr", | "nullable" : true, @@ -1044,26 +762,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "d_arr", | "nullable" : true, @@ -1084,36 +784,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 + | "children" : [ ] | } ] - | } + | } ] | } ] | }, | "batches" : [ { @@ -1204,23 +877,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "b_struct", | "nullable" : true, @@ -1235,23 +893,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : 
"VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "c_struct", | "nullable" : false, @@ -1266,23 +909,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "d_struct", | "nullable" : true, @@ -1303,30 +931,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 + | "children" : [ ] | } ] - | } + | } ] | } ] | }, | "batches" : [ { @@ -1413,16 +1020,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -1431,16 +1029,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1471,16 +1060,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -1489,16 +1069,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1600,16 +1171,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_i", | "type" : { @@ -1618,16 +1180,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | 
"typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1658,16 +1211,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "a_i", | "type" : { @@ -1676,16 +1220,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 068a17bf772e1..acac7ca5fabec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -30,15 +30,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableBitVector] + .createVector(allocator).asInstanceOf[BitVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, if (i % 2 == 0) 1 else 0) + vector.setSafe(i, if (i % 2 == 0) 1 else 0) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) @@ -59,15 +58,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableTinyIntVector] + .createVector(allocator).asInstanceOf[TinyIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toByte) + vector.setSafe(i, i.toByte) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) @@ -88,15 +86,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableSmallIntVector] + .createVector(allocator).asInstanceOf[SmallIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toShort) + vector.setSafe(i, i.toShort) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) @@ -117,15 +114,14 @@ class 
ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i) + vector.setSafe(i, i) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) @@ -146,15 +142,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableBigIntVector] + .createVector(allocator).asInstanceOf[BigIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toLong) + vector.setSafe(i, i.toLong) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) @@ -175,15 +170,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableFloat4Vector] + .createVector(allocator).asInstanceOf[Float4Vector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toFloat) + vector.setSafe(i, i.toFloat) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) @@ -204,15 +198,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableFloat8Vector] + .createVector(allocator).asInstanceOf[Float8Vector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toDouble) + vector.setSafe(i, i.toDouble) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) @@ -233,16 +226,15 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableVarCharVector] + .createVector(allocator).asInstanceOf[VarCharVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") - mutator.setSafe(i, utf8, 0, utf8.length) + vector.setSafe(i, utf8, 0, utf8.length) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) 
assert(columnVector.dataType === StringType) @@ -261,16 +253,15 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableVarBinaryVector] + .createVector(allocator).asInstanceOf[VarBinaryVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") - mutator.setSafe(i, utf8, 0, utf8.length) + vector.setSafe(i, utf8, 0, utf8.length) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) @@ -291,31 +282,29 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() - val mutator = vector.getMutator() - val elementVector = vector.getDataVector().asInstanceOf[NullableIntVector] - val elementMutator = elementVector.getMutator() + val elementVector = vector.getDataVector().asInstanceOf[IntVector] // [1, 2] - mutator.startNewValue(0) - elementMutator.setSafe(0, 1) - elementMutator.setSafe(1, 2) - mutator.endValue(0, 2) + vector.startNewValue(0) + elementVector.setSafe(0, 1) + elementVector.setSafe(1, 2) + vector.endValue(0, 2) // [3, null, 5] - mutator.startNewValue(1) - elementMutator.setSafe(2, 3) - elementMutator.setNull(3) - elementMutator.setSafe(4, 5) - mutator.endValue(1, 3) + vector.startNewValue(1) + elementVector.setSafe(2, 3) + elementVector.setNull(3) + elementVector.setSafe(4, 5) + vector.endValue(1, 3) // null // [] - mutator.startNewValue(3) - mutator.endValue(3, 0) + vector.startNewValue(3) + vector.endValue(3, 0) - elementMutator.setValueCount(5) - mutator.setValueCount(4) + elementVector.setValueCount(5) + vector.setValueCount(4) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) @@ -348,38 +337,35 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() - val mutator = vector.getMutator() - val intVector = vector.getChildByOrdinal(0).asInstanceOf[NullableIntVector] - val intMutator = intVector.getMutator() - val longVector = vector.getChildByOrdinal(1).asInstanceOf[NullableBigIntVector] - val longMutator = longVector.getMutator() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] // (1, 1L) - mutator.setIndexDefined(0) - intMutator.setSafe(0, 1) - longMutator.setSafe(0, 1L) + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) // (2, null) - mutator.setIndexDefined(1) - intMutator.setSafe(1, 2) - longMutator.setNull(1) + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) // (null, 3L) - mutator.setIndexDefined(2) - intMutator.setNull(2) - longMutator.setSafe(2, 3L) + vector.setIndexDefined(2) + intVector.setNull(2) + longVector.setSafe(2, 3L) // null - mutator.setNull(3) + vector.setNull(3) // (5, 5L) - mutator.setIndexDefined(4) - intMutator.setSafe(4, 5) - longMutator.setSafe(4, 5L) + vector.setIndexDefined(4) + intVector.setSafe(4, 5) + 
longVector.setSafe(4, 5L) - intMutator.setValueCount(5) - longMutator.setValueCount(5) - mutator.setValueCount(5) + intVector.setValueCount(5) + longVector.setValueCount(5) + vector.setValueCount(5) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0ae4f2d117609..e5d1bc24bc8e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random -import org.apache.arrow.vector.NullableIntVector +import org.apache.arrow.vector.IntVector import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode @@ -1151,22 +1151,20 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector1.allocateNew() - val mutator1 = vector1.getMutator() val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector2.allocateNew() - val mutator2 = vector2.getMutator() (0 until 10).foreach { i => - mutator1.setSafe(i, i) - mutator2.setSafe(i + 1, i) + vector1.setSafe(i, i) + vector2.setSafe(i + 1, i) } - mutator1.setNull(10) - mutator1.setValueCount(11) - mutator2.setNull(0) - mutator2.setValueCount(11) + vector1.setNull(10) + vector1.setValueCount(11) + vector2.setNull(0) + vector2.setValueCount(11) val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))
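
All of the test-suite edits above track the two breaking changes in Arrow 0.8.0 that this patch adapts to: the integration-test JSON no longer carries a per-field "typeLayout" block (and serializes floating-point NaN as a bare NaN token rather than the quoted string "NaN"), and the vector classes drop both the "Nullable" name prefix and the getMutator()/getAccessor() indirection, so reads and writes happen directly on the vector. The sketch below shows the new write/read path against Arrow's public API alone; it deliberately avoids Spark's ArrowUtils helper used in the suites, and the object name is invented for illustration.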
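
// Minimal sketch of the Arrow 0.8 vector API the suites migrate to.
// Only Arrow's own classes are used; MutatorMigrationSketch is a made-up name.
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.IntVector

object MutatorMigrationSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    // Arrow 0.8 drops the "Nullable" prefix: NullableIntVector -> IntVector.
    val vector = new IntVector("int", allocator)
    vector.allocateNew()

    // Writes go through the vector itself; vector.getMutator() is gone.
    (0 until 10).foreach { i => vector.setSafe(i, i) }
    vector.setNull(10)
    vector.setValueCount(11)

    // Reads likewise moved off the old Accessor onto the vector.
    assert(vector.getValueCount == 11)
    assert(vector.isNull(10))
    assert(vector.get(3) == 3)

    vector.close()
    allocator.close()
  }
}

Wrapping the same vector in new ArrowColumnVector(vector), as the suites above do, then exposes it through Spark's ColumnVector read API without any per-type accessor plumbing.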