diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 225c21fb18b07..ea7af01121a1b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -620,6 +620,8 @@ private[serializer] object KryoSerializer { "org.apache.spark.sql.columnar.CachedBatchSerializer", "org.apache.spark.sql.columnar.SimpleMetricsCachedBatchSerializer", "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatch", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer", "org.apache.spark.ml.attribute.Attribute", "org.apache.spark.ml.attribute.AttributeGroup", diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 8b3c38523c07b..a100f0ccb1ce3 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -80,6 +80,8 @@ subitems: - text: Caching Data url: sql-performance-tuning.html#caching-data + - text: Arrow Cache Format + url: sql-arrow-cache-format.html - text: Tuning Partitions url: sql-performance-tuning.html#tuning-partitions - text: Leveraging Statistics diff --git a/docs/sql-arrow-cache-format.md b/docs/sql-arrow-cache-format.md new file mode 100644 index 0000000000000..86f6063b45754 --- /dev/null +++ b/docs/sql-arrow-cache-format.md @@ -0,0 +1,378 @@ +--- +layout: global +title: Apache Arrow Cache Format +displayTitle: Apache Arrow Cache Format +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +## Overview + +Apache Spark supports using Apache Arrow as an alternative cache format for in-memory Dataset caching. This format provides improved performance for certain workloads, especially when working with columnar data sources like Parquet and ORC. + +## Benefits + +The Arrow cache format offers several advantages over the default cache format: + +- **Zero-copy reads** when input is already in Arrow format (e.g., Arrow-based data sources, re-caching Arrow cached data) +- **Better filter pushdown** with min/max statistics for partition pruning +- **Off-heap memory management** via Arrow allocators +- **Efficient compression** with zstd and lz4 codecs +- **Arrow ecosystem interoperability** for data sharing + +**Note**: Spark's built-in Parquet/ORC readers use internal column vectors (`OnHeapColumnVector`/`OffHeapColumnVector`), not Arrow format, so they don't benefit from zero-copy optimization. + +## Configuration + +`spark.sql.cache.serializer` is a static SQL configuration, so it must be set when the +SparkSession is built and cannot be changed on a running session (`spark.conf.set` rejects static +keys with `CANNOT_MODIFY_CONFIG`): + +```scala +val spark = SparkSession.builder() + .appName("MyApp") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() +``` + +**Note**: This config selects the cache serializer for the whole session; once set, this +serializer handles every cached relation. There is no automatic per-relation fallback to another +cache serializer based on the data types involved (see +[Supported Data Types](#supported-data-types) for how unsupported types are handled). The chosen +serializer is also cached process-wide on first use, so switching cache formats within a JVM that +has already materialized a cache requires a fresh JVM (see +[Migration from Default Cache](#migration-from-default-cache)). + +## Usage + +Once configured, use cache operations as normal: + +```scala +// Cache a DataFrame +val df = spark.read.parquet("data.parquet") +df.cache() + +// Use cached data +df.filter("age > 30").count() + +// Uncache when done +df.unpersist() +``` + +## Compression + +Arrow cache supports multiple compression codecs. Configure compression with: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") +``` + +Available options: +- `none` - No compression (fastest, largest size, **default**) +- `lz4` - LZ4 compression (fast, good compression) +- `zstd` - Zstandard compression (slower, best compression) + +For zstd, you can also configure the compression level. Positive values (up to 22) give better +compression but slower speed; negative values give ultra-fast compression with lower ratios: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.zstd.level", "3") // Default: 3 +``` + +## Vectorized Reader + +Enable vectorized reading for better performance with primitive types: + +```scala +spark.conf.set("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") +``` + +When enabled, cached data is read as columnar batches instead of rows, which can significantly improve performance for columnar operations. + +## Performance Characteristics + +In our benchmarks, the Arrow cache format performs best on the following workloads. Actual +results depend on data types, compression settings, and hardware, and the default cache format +can be faster in some cases (for example, with higher compression levels): + +1. **Filter-Heavy Workloads**: Queries with selective filters benefit from min/max statistics. +2. **Columnar Operations**: Aggregations and projections on cached data benefit from the Arrow format. +3. **Parquet/ORC Caching**: Arrow's batch processing helps even without the zero-copy path. +4. **Re-caching with Column Projection**: Dropping columns from Arrow-cached data preserves the + `ArrowColumnVector` format, enabling true zero-copy extraction and the largest gains. + +### Benchmark Results + +The numbers below are illustrative results from one run on an Apple M4 Max (OpenJDK 21.0.8) and +will vary with hardware, JDK, and compression settings. They are not a guarantee. For the +authoritative, regularly regenerated numbers, see +`sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt` and the `ArrowCacheBenchmark` suite. + +| Workload | Default Cache | Arrow Cache | Speedup | +|----------|--------------|-------------|---------| +| Write + Read (5M rows, 3 primitive columns) | 153.7 ns/row | 74.2 ns/row | **~2X faster** | +| Cache then filter (5M rows) | 100.1 ns/row | 70.8 ns/row | **~1.4X faster** | +| Columnar input from Parquet (2M rows, 3 primitive columns) | 195.3 ns/row | 113.1 ns/row | **~1.7X faster** | +| Re-cache with zero-copy (2M rows, 2 columns) | 123.3 ns/row | 38.5 ns/row | **~3.2X faster** | + +**Notes**: +- **Write + Read**: Significant improvement from efficient Arrow serialization and vectorized operations +- **Cache then filter**: This measures end-to-end cache build plus a filtered scan, comparing the two cache formats. Both formats collect min/max statistics and can prune batches, so the difference reflects overall cache+scan throughput rather than pruning unique to Arrow +- **Parquet caching**: Shows improvement despite Spark's Parquet reader producing `OnHeapColumnVector`/`OffHeapColumnVector` rather than `ArrowColumnVector`, due to Arrow's efficient batch processing +- **Re-cache with zero-copy**: When caching a subset of columns from Arrow-cached data (e.g., `df.drop("column")`), the remaining columns preserve their `ArrowColumnVector` format, enabling true zero-copy extraction and achieving the best performance +- **Zero-copy benefits** only apply when input is already `ArrowColumnVector` (e.g., Python Arrow sources, re-caching Arrow cached data with column projection) + +## Supported Data Types + +Arrow cache supports the following data types: + +### Primitive Types +- BooleanType +- ByteType, ShortType, IntegerType, LongType +- FloatType, DoubleType +- DecimalType (all precision/scale combinations) +- NullType + +### Temporal Types +- DateType +- TimestampType +- TimestampNTZType +- TimeType + +### Interval Types +- YearMonthIntervalType +- DayTimeIntervalType +- CalendarIntervalType (see the value-range note below) + +`CalendarIntervalType` is stored through Arrow's nanosecond-based interval representation, so its +microsecond component must fit within +/-(`Long.MaxValue` / 1000). Caching a value beyond that range +fails with a clear error rather than silently corrupting the value. The default cache serializer +does not have this restriction. + +### String and Binary +- StringType (including collated strings) +- BinaryType + +### Complex Types +- ArrayType +- StructType +- MapType +- Nested combinations of the above + +### Other Types +- VariantType +- GeometryType, GeographyType +- User-defined types (UDTs) whose underlying representation is itself supported + +### Unsupported Types + +Arrow cache covers every type the default cache serializer supports, plus some it +does not (for example geometry and geography). Types that Arrow cannot represent +(such as `ObjectType`) are not silently dropped or routed to a different cache +serializer: there is no per-type fallback, because the cache serializer is chosen +once via the static `spark.sql.cache.serializer` configuration and then handles +every cached relation. Attempting to cache an unsupported type fails with an +`UNSUPPORTED_DATATYPE` error when the cache is materialized. + +## Statistics and Filter Pushdown + +Arrow cache automatically collects min/max statistics for the following types: +- Boolean +- Numeric types (Byte, Short, Int, Long, Float, Double) +- Decimal +- Date, Timestamp, and Timestamp without time zone (TIMESTAMP_NTZ) +- Time +- Year-month and day-time intervals +- String (using collation-aware comparison for collated strings) + +Other types (Binary, Variant, calendar intervals, and complex types such as +Array/Struct/Map) are cached but do not contribute min/max bounds, so they only +record null counts and sizes. + +These statistics enable partition pruning when filtering: + +```scala +val df = spark.range(10000000).cache() + +// This filter can skip batches using min/max statistics +df.filter("id > 5000000").count() +``` + +## Memory Management + +The cached data itself lives on the JVM heap, not off-heap. Each cached batch is stored as a +serialized Arrow IPC byte array (`Array[Byte]`), and the default `Dataset.cache()` storage level is +the deserialized `MEMORY_AND_DISK`, so those bytes are retained as ordinary heap objects (and spill +to disk under memory pressure). Arrow's off-heap allocators are used only for the transient +`VectorSchemaRoot`s created while encoding a batch for caching and while decoding a batch on read; +these are released as soon as the encode/decode completes. + +**Sizing implications**: +- Size the JVM heap (executor memory) for the cached data, since that is where it resides. This is + the main knob for cache capacity. +- `spark.executor.memoryOverhead` covers the transient off-heap encode/decode buffers, which are + proportional to a single batch, not to the total cached size. It generally does not need to grow + with the size of the cache. +- Arrow cache is often **more memory-efficient** than the default cache for the heap-resident bytes: + efficient zstd/lz4 compression, a compact columnar layout without per-value Java object overhead, + and better compression ratios for strings and complex types. + +**Memory Cleanup**: +- The transient off-heap encode/decode roots are released when each task completes. +- The heap-resident cached bytes are released when the DataFrame is unpersisted or evicted, like any + other cached block. + +You can monitor cache block sizes through the Storage tab in the Spark UI. + +## Limitations and Considerations + +1. **Static Configuration**: Cache serializer must be set before SparkSession creation +2. **Memory Overhead**: Arrow format has small per-batch overhead +3. **Compatibility**: Cannot mix cache formats - recache needed when switching +4. **Compression Trade-off**: Higher compression = lower memory but slower reads + +## Migration from Default Cache + +The cache serializer is resolved from `spark.sql.cache.serializer` only on first use and is then +held in a process-wide field that is not reset when a SparkSession stops. As a result, **switching +cache formats requires a fresh JVM** once any cache has been materialized -- stopping and +rebuilding the SparkSession in the same process keeps using the originally resolved serializer. + +To migrate from the default cache to Arrow cache: + +1. **Start a new JVM / driver process** (a brand-new Spark application). +2. **Build the SparkSession with the Arrow serializer**: + ```scala + val spark = SparkSession.builder() + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() + ``` +3. **Cache your DataFrames** as usual. + +**Note**: Cache data is never shared across formats; each application caches in whichever format +its serializer produces. + +## Troubleshooting + +### Out of Memory Errors + +If you encounter OOM errors with Arrow cache: + +1. Reduce batch size: + ```scala + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000") // Default: 10000 + ``` + +2. Enable compression: + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + ``` + +3. Reduce compression level: + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.zstd.level", "1") + ``` + +### Slow Performance + +If Arrow cache is slower than expected: + +1. Enable vectorized reader: + ```scala + spark.conf.set("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") + ``` + +2. Reduce or disable compression (decompression is part of every read): + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.zstd.level", "1") // faster, less ratio + // or, for read-heavy workloads where memory is not the constraint: + spark.conf.set("spark.sql.execution.arrow.compression.codec", "none") + ``` + Note: `lz4` is not recommended unless the native LZ4 library is on the classpath. Without it, + Arrow falls back to the pure-Java Commons Compress LZ4, which is far slower than zstd. + +3. Increase batch size (if memory allows): + ```scala + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000") + ``` + +## Configuration Reference + +| Configuration | Default | Description | +|---------------|---------|-------------| +| `spark.sql.cache.serializer` | DefaultCachedBatchSerializer | Cache format serializer class | +| `spark.sql.execution.arrow.compression.codec` | `none` | Compression codec (none, lz4, zstd) | +| `spark.sql.execution.arrow.compression.zstd.level` | `3` | Zstd compression level (negative = faster, up to 22) | +| `spark.sql.execution.arrow.maxRecordsPerBatch` | `10000` | Maximum rows per Arrow batch | +| `spark.sql.inMemoryColumnarStorage.enableVectorizedReader` | `true` | Enable vectorized cache reading | + +## Example: Complete Application + +```scala +import org.apache.spark.sql.SparkSession + +object ArrowCacheExample { + def main(args: Array[String]): Unit = { + // Create SparkSession with Arrow cache + val spark = SparkSession.builder() + .appName("ArrowCacheExample") + .master("local[*]") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .config("spark.sql.execution.arrow.compression.codec", "zstd") + .config("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") + .getOrCreate() + + try { + // Read columnar data source + val df = spark.read.parquet("large_dataset.parquet") + + // Cache with Arrow format + df.cache() + + // Queries benefit from zero-copy reads and statistics + val result1 = df.filter("age > 30").select("name", "age").count() + println(s"Filtered count: $result1") + + val result2 = df.groupBy("country").agg(sum("sales")).collect() + println(s"Aggregation result: ${result2.mkString(", ")}") + + // Uncache when done + df.unpersist() + + } finally { + spark.stop() + } + } +} +``` + +## Best Practices + +1. **Use with Columnar Sources**: Maximum benefit with Parquet/ORC +2. **Enable Statistics**: Let Arrow cache collect min/max for filter pushdown +3. **Monitor Memory**: Watch cache block sizes in the Spark UI Storage tab; the cached data lives on the JVM heap +4. **Test First**: Benchmark your workload before production deployment +5. **Compression**: Start with `none` for read-heavy workloads, or `zstd` (default level 3) when memory matters; avoid `lz4` unless the native LZ4 library is on the classpath +6. **Vectorization**: Enable vectorized reader for primitive-heavy workloads + +## Further Reading + +- [Apache Arrow Project](https://arrow.apache.org/) +- [Spark Caching Documentation](https://spark.apache.org/docs/latest/sql-performance-tuning.html#caching-data-in-memory) +- [Arrow IPC Format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index ebae89cbe5e0f..c7da8d11bf04a 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -32,6 +32,10 @@ memory usage and GC pressure. You can call `spark.catalog.uncacheTable("tableNam To list relations cached with an explicit name, use `spark.catalog.listCachedTables()`. Entries cached only via `Dataset.cache()` without a name are not included. +Spark supports two cache formats: +- **Default cache format**: The standard in-memory columnar cache (used by default). +- **Arrow cache format**: An Apache Arrow-based cache that can improve read performance for columnar workloads and enables Arrow ecosystem interoperability. See [Arrow Cache Format documentation](sql-arrow-cache-format.html) for details and configuration. + Configuration of in-memory caching can be done via `spark.conf.set` or by running `SET key=value` commands using SQL. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index a06a77d9d1139..ed3b92c835771 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -38,6 +38,50 @@ private[sql] object ArrowUtils { // todo: support more types. + /** + * Check if a Spark DataType is supported by Arrow. This recursively checks complex types + * (Array, Struct, Map). + * + * Note: This checks compatibility with toArrowField(), not toArrowType(). Types like + * GeometryType, GeographyType, and VariantType are not supported by toArrowType() (which only + * handles primitive Arrow types), but ARE supported by toArrowField() which converts them to + * Arrow Struct representations with metadata. Since Arrow cache uses toArrowField() via + * toArrowSchema() to create the schema, these types are supported. + */ + def isSupportedByArrow(dt: DataType): Boolean = { + dt match { + // Primitive types + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: StringType | BinaryType | NullType => + true + + // Decimal + case _: DecimalType => true + + // Temporal types + case DateType | TimestampType | TimestampNTZType | _: TimeType => true + + // Interval types + case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType => true + + // Complex types - recursively check element types + case ArrayType(elementType, _) => isSupportedByArrow(elementType) + case StructType(fields) => fields.forall(f => isSupportedByArrow(f.dataType)) + case MapType(keyType, valueType, _) => + isSupportedByArrow(keyType) && isSupportedByArrow(valueType) + + // Special types + // Note: These are not in toArrowType(), but are handled by toArrowField() + case udt: UserDefinedType[_] => isSupportedByArrow(udt.sqlType) + case _: GeometryType => true // Converted to Struct with srid + wkb fields + case _: GeographyType => true // Converted to Struct with srid + wkb fields + case _: VariantType => true // Converted to Struct with value + metadata fields + + // Unsupported types + case _ => false + } + } + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = TypeApiOps(dt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6776f88ed1ef8..21d33f2613dd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4767,6 +4767,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ARROW_CACHE_PREFETCH_ENABLED = + buildConf("spark.sql.execution.arrow.cache.prefetch.enabled") + .doc("When true, Arrow cache read path prefetches and decompresses the next batch " + + "in a background thread while the current batch is being consumed. This can " + + "significantly improve read performance for compressed Arrow caches (e.g., ZSTD) " + + "by overlapping decompression with consumption. Increases memory usage by up to " + + "one additional batch worth of Arrow vectors.") + .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch") .doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " + @@ -8521,6 +8533,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def arrowPySparkUDFColumnarInputEnabled: Boolean = getConf(ARROW_PYSPARK_UDF_COLUMNAR_INPUT_ENABLED) + def arrowCachePrefetchEnabled: Boolean = getConf(ARROW_CACHE_PREFETCH_ENABLED) + def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int = getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 9855244c968ce..658ab5f5312d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -171,7 +171,10 @@ object StaticSQLConf { "org.apache.spark.sql.columnar.CachedBatchSerializer. It will be used to " + "translate SQL data into a format that can more efficiently be cached. The underlying " + "API is subject to change so use with caution. Multiple classes cannot be specified. " + - "The class must have a no-arg constructor.") + "The class must have a no-arg constructor. Available implementations include: " + + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer (default) and " + + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer (Arrow format with " + + "zero-copy columnar reads and better Arrow ecosystem interoperability).") .version("3.1.0") .stringConf .createWithDefault("org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..971d035b7b6c0 --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1854 1922 97 2.7 370.8 1.0X +Default cache - write + read (uncompressed) 1159 1165 8 4.3 231.8 1.6X +Arrow cache - write + read 1300 1315 21 3.8 260.0 1.4X +Arrow cache - write + read (zstd level -1) 1808 1811 4 2.8 361.6 1.0X +Arrow cache - write + read (zstd level 1) 1814 1830 23 2.8 362.8 1.0X +Arrow cache - write + read (zstd level 3) 1902 1929 39 2.6 380.4 1.0X + + +================================================================================================ +Cache then filter +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows, then filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1662 1683 29 3.0 332.5 1.0X +Default cache - filter (uncompressed) 1312 1312 0 3.8 262.4 1.3X +Arrow cache - filter 1447 1462 21 3.5 289.4 1.1X +Arrow cache - filter (zstd level -1) 1729 1757 40 2.9 345.8 1.0X +Arrow cache - filter (zstd level 1) 1787 1799 17 2.8 357.3 0.9X +Arrow cache - filter (zstd level 3) 1951 1955 5 2.6 390.3 0.9X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1545 1619 104 1.3 772.7 1.0X +Default cache - columnar input (uncompressed) 1313 1336 33 1.5 656.4 1.2X +Arrow cache - columnar input 1353 1378 35 1.5 676.7 1.1X +Arrow cache - columnar input (zstd level -1) 1535 1573 54 1.3 767.6 1.0X +Arrow cache - columnar input (zstd level 1) 1619 1622 5 1.2 809.6 1.0X +Arrow cache - columnar input (zstd level 3) 1708 1709 2 1.2 853.8 0.9X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 411 428 20 4.9 205.7 1.0X +Default cache - cache a cached DF (uncompressed) 191 210 26 10.5 95.7 2.2X +Arrow cache - cache a cached DF (zero-copy) 137 156 24 14.6 68.4 3.0X +Arrow cache - cache a cached DF (zstd level -1) 327 343 18 6.1 163.3 1.3X +Arrow cache - cache a cached DF (zstd level 1) 338 341 3 5.9 168.8 1.2X +Arrow cache - cache a cached DF (zstd level 3) 352 357 3 5.7 176.2 1.2X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 10855 11142 406 0.5 2171.0 1.0X +Default cache - select 1 of 20 (uncompressed) 4135 4149 20 1.2 827.0 2.6X +Arrow cache - select 1 of 20 5179 5280 144 1.0 1035.8 2.1X +Arrow cache - select 1 of 20 (zstd level -1) 9258 9283 35 0.5 1851.7 1.2X +Arrow cache - select 1 of 20 (zstd level 1) 9437 9603 234 0.5 1887.4 1.2X +Arrow cache - select 1 of 20 (zstd level 3) 9778 9794 23 0.5 1955.5 1.1X + + + diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..6deddba5300ec --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1686 1723 53 3.0 337.2 1.0X +Default cache - write + read (uncompressed) 1045 1065 27 4.8 209.1 1.6X +Arrow cache - write + read 1268 1305 53 3.9 253.6 1.3X +Arrow cache - write + read (zstd level -1) 1724 1725 1 2.9 344.8 1.0X +Arrow cache - write + read (zstd level 1) 1770 1794 34 2.8 354.0 1.0X +Arrow cache - write + read (zstd level 3) 1857 1893 50 2.7 371.4 0.9X + + +================================================================================================ +Cache then filter +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows, then filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1426 1432 8 3.5 285.3 1.0X +Default cache - filter (uncompressed) 1252 1274 31 4.0 250.4 1.1X +Arrow cache - filter 1289 1295 8 3.9 257.8 1.1X +Arrow cache - filter (zstd level -1) 1712 1716 7 2.9 342.4 0.8X +Arrow cache - filter (zstd level 1) 1747 1759 16 2.9 349.5 0.8X +Arrow cache - filter (zstd level 3) 1812 1848 50 2.8 362.4 0.8X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1461 1486 35 1.4 730.6 1.0X +Default cache - columnar input (uncompressed) 1219 1227 12 1.6 609.3 1.2X +Arrow cache - columnar input 1253 1273 27 1.6 626.7 1.2X +Arrow cache - columnar input (zstd level -1) 1448 1460 17 1.4 723.8 1.0X +Arrow cache - columnar input (zstd level 1) 1504 1504 0 1.3 752.0 1.0X +Arrow cache - columnar input (zstd level 3) 1578 1587 13 1.3 788.9 0.9X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 386 409 28 5.2 193.1 1.0X +Default cache - cache a cached DF (uncompressed) 194 217 26 10.3 96.8 2.0X +Arrow cache - cache a cached DF (zero-copy) 132 144 10 15.2 65.9 2.9X +Arrow cache - cache a cached DF (zstd level -1) 321 324 7 6.2 160.3 1.2X +Arrow cache - cache a cached DF (zstd level 1) 333 341 7 6.0 166.7 1.2X +Arrow cache - cache a cached DF (zstd level 3) 350 356 12 5.7 174.8 1.1X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 9310 9426 164 0.5 1862.0 1.0X +Default cache - select 1 of 20 (uncompressed) 3929 3994 92 1.3 785.7 2.4X +Arrow cache - select 1 of 20 5150 5225 106 1.0 1030.0 1.8X +Arrow cache - select 1 of 20 (zstd level -1) 9265 9376 156 0.5 1853.1 1.0X +Arrow cache - select 1 of 20 (zstd level 1) 9296 9351 78 0.5 1859.3 1.0X +Arrow cache - select 1 of 20 (zstd level 3) 9970 9982 18 0.5 1994.0 0.9X + + + diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-results.txt new file mode 100644 index 0000000000000..01edba6d9bf05 --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1765 1794 41 2.8 353.0 1.0X +Default cache - write + read (uncompressed) 1154 1168 20 4.3 230.9 1.5X +Arrow cache - write + read 1279 1290 15 3.9 255.9 1.4X +Arrow cache - write + read (zstd level -1) 1736 1736 1 2.9 347.1 1.0X +Arrow cache - write + read (zstd level 1) 1795 1815 28 2.8 359.0 1.0X +Arrow cache - write + read (zstd level 3) 1858 1862 6 2.7 371.5 1.0X + + +================================================================================================ +Cache then filter +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows, then filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1454 1478 33 3.4 290.9 1.0X +Default cache - filter (uncompressed) 1312 1315 4 3.8 262.5 1.1X +Arrow cache - filter 1391 1394 4 3.6 278.2 1.0X +Arrow cache - filter (zstd level -1) 1737 1751 21 2.9 347.4 0.8X +Arrow cache - filter (zstd level 1) 1808 1814 7 2.8 361.7 0.8X +Arrow cache - filter (zstd level 3) 1874 1874 0 2.7 374.8 0.8X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 9V74 80-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1492 1498 8 1.3 746.1 1.0X +Default cache - columnar input (uncompressed) 1263 1264 2 1.6 631.4 1.2X +Arrow cache - columnar input 1318 1318 0 1.5 658.9 1.1X +Arrow cache - columnar input (zstd level -1) 1541 1542 1 1.3 770.7 1.0X +Arrow cache - columnar input (zstd level 1) 1556 1565 12 1.3 778.1 1.0X +Arrow cache - columnar input (zstd level 3) 1621 1625 5 1.2 810.5 0.9X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 9V74 80-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 402 413 11 5.0 201.2 1.0X +Default cache - cache a cached DF (uncompressed) 208 225 28 9.6 103.9 1.9X +Arrow cache - cache a cached DF (zero-copy) 144 155 17 13.9 72.2 2.8X +Arrow cache - cache a cached DF (zstd level -1) 343 346 4 5.8 171.5 1.2X +Arrow cache - cache a cached DF (zstd level 1) 356 367 15 5.6 178.0 1.1X +Arrow cache - cache a cached DF (zstd level 3) 364 366 1 5.5 181.9 1.1X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 10187 10276 126 0.5 2037.4 1.0X +Default cache - select 1 of 20 (uncompressed) 4115 4129 19 1.2 823.1 2.5X +Arrow cache - select 1 of 20 4927 4931 6 1.0 985.4 2.1X +Arrow cache - select 1 of 20 (zstd level -1) 9043 9098 77 0.6 1808.6 1.1X +Arrow cache - select 1 of 20 (zstd level 1) 9139 9216 109 0.5 1827.9 1.1X +Arrow cache - select 1 of 20 (zstd level 3) 9420 9453 47 0.5 1884.0 1.1X + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala new file mode 100644 index 0000000000000..a21def96afb2e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch + +/** + * A [[SimpleMetricsCachedBatch]] implementation that stores Arrow RecordBatch data + * in Apache Arrow IPC streaming format. + * + * The batch contains: + * - `numRows`: Number of rows in this batch + * - `arrowData`: Serialized Arrow RecordBatch in IPC streaming format (with optional compression) + * - `stats`: Per-column statistics for partition pruning (upperBound, lowerBound, nullCount, etc.) + * + * This format enables: + * - Zero-copy columnar reads when output is ColumnarBatch with ArrowColumnVector + * - Efficient interoperability with Arrow ecosystem + * - Off-heap memory management via Arrow allocators + * - Built-in compression support (zstd, lz4) at Arrow level + * + * @param numRows Number of rows in this cached batch + * @param arrowData Serialized Arrow RecordBatch in IPC streaming format + * @param stats Per-column statistics as InternalRow (5 fields per column: + * upperBound, lowerBound, nullCount, rowCount, sizeInBytes) + */ +case class ArrowCachedBatch( + numRows: Int, + arrowData: Array[Byte], + stats: InternalRow) extends SimpleMetricsCachedBatch diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala new file mode 100644 index 0000000000000..4d3b44c2b3d44 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala @@ -0,0 +1,1533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. + * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + // supportsColumnarInput selects the columnar-vs-row input path; it does not gate which schemas + // this serializer accepts. The cache framework has no per-type fallback to another serializer + // (whatever spark.sql.cache.serializer selects handles every cached relation), so returning + // false here only routes input through convertInternalRowToCachedBatch, which is still this + // serializer. Type support is enforced once per partition by checkSupportedSchema below; the + // only real precondition for columnar input is that the plan can produce columnar output, which + // InMemoryRelation already checks via cachedPlan.supportsColumnar before calling this. + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val maxBytesPerBatch = conf.arrowMaxBytesPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + maxBytesPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + ArrowCachedBatchSerializer.checkSupportedSchema(schema) + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId, + prefetchEnabled) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + // Geometry/Geography are represented as an Arrow struct (srid + wkb); the fast-path + // ArrowColumnReader does not handle them, so route them through the fallback. + case _: GeometryType | _: GeographyType => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + /** + * Run an Arrow write block, translating a CalendarInterval microsecond overflow into a clear + * error. Arrow's IntervalMonthDayNano representation is nanosecond-based, so writing a + * CalendarInterval multiplies its microseconds by 1000 with Math.multiplyExact; Spark allows the + * full Long microsecond domain, so values beyond Long.MaxValue/1000 overflow and otherwise abort + * with an opaque "long overflow" ArithmeticException. The catch is only installed when the schema + * actually contains a CalendarInterval column (hasInterval), so there is no per-row cost and no + * effect on schemas without intervals; the try is entered once per batch, not per row. + */ + def withIntervalOverflowTranslation[T](hasInterval: Boolean)(block: => T): T = { + if (!hasInterval) { + block + } else { + try { + block + } catch { + case e: ArithmeticException => + throw SparkException.internalError( + "Arrow cache cannot represent a CalendarInterval whose microseconds exceed " + + "+/-(Long.MaxValue / 1000): Arrow stores intervals in nanoseconds and the " + + s"conversion overflows. Original error: ${e.getMessage}") + } + } + } + + /** Whether the schema has a top-level CalendarInterval column (the only overflow-prone type). */ + def hasCalendarInterval(schema: Seq[Attribute]): Boolean = + schema.exists(_.dataType == CalendarIntervalType) + + /** + * Fail fast, once per partition on the driver-facing entry points, if any column type cannot be + * represented by the Arrow cache. This is the actual capability gate (supportsColumnarInput only + * chooses the input path). Without it, an unsupported type would otherwise surface as a less + * obvious failure deeper in schema conversion or statistics collection. + */ + def checkSupportedSchema(schema: Seq[Attribute]): Unit = { + schema.find(attr => !ArrowUtils.isSupportedByArrow(attr.dataType)).foreach { attr => + throw SparkException.internalError( + s"Arrow cache does not support column '${attr.name}' of type ${attr.dataType.sql}. " + + "Use the default cache serializer for this data, or cast the column to a supported type.") + } + } + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + // The codec instance must be constructed directly so that compressionLevel is honored: + // CompressionCodec.Factory.createCodec(codecType) ignores the level and builds a codec at + // the default level. The level only matters on the write side; the read side looks up the + // codec by the type recorded in the IPC message. + case "zstd" => new ZstdCompressionCodec(compressionLevel) + case "lz4" => new Lz4CompressionCodec() + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + /** + * Shut down a prefetch worker during task cleanup without leaking the root it may have produced. + * + * The prefetch worker deserializes the next batch into a fresh [[VectorSchemaRoot]] off-thread. + * If task completion runs while a result is in flight (e.g. a LIMIT consumer stops early), + * cancelling and discarding the future would drop a root that was already (or is about to be) + * produced, and the subsequent `allocator.close()` would fail with "Memory was leaked by query". + * + * This stops accepting new work, waits for the worker to finish so no root is produced after we + * stop looking, then closes any completed result. Always returns null so the caller can null out + * its future reference. Safe to call with a null executor or future. + */ + def drainAndClosePrefetch( + executor: java.util.concurrent.ExecutorService, + future: java.util.concurrent.Future[VectorSchemaRoot]): java.util.concurrent.Future[ + VectorSchemaRoot] = { + // Drain and join the worker uninterruptibly, then close any root it produced, before the + // caller closes the allocator. This runs from a task-completion listener, which can fire with + // the task thread already interrupted (e.g. a killed task). If we let awaitTermination or + // future.get observe the interrupt and bail early, the worker could still be allocating into, + // or have already returned, a root that we then neither join nor close -- and the subsequent + // allocator.close() would race the worker or fail with "Memory was leaked by query". So we + // clear the interrupt for the duration and restore it only at the end. + val wasInterrupted = Thread.interrupted() + try { + if (executor != null) { + executor.shutdown() + var terminated = false + while (!terminated) { + try { + terminated = + executor.awaitTermination(Long.MaxValue, java.util.concurrent.TimeUnit.NANOSECONDS) + } catch { + // Re-clear and keep waiting: we must not leave the worker running. + case _: InterruptedException => Thread.interrupted() + } + } + } + if (future != null) { + try { + // The worker has terminated, so this does not block; close the root it produced. + val root = future.get() + if (root != null) { + root.close() + } + } catch { + // The batch was never produced (cancelled/failed); nothing to close. + case _: java.util.concurrent.CancellationException => + case _: java.util.concurrent.ExecutionException => + case _: InterruptedException => // already terminated; nothing in flight + } + } + } finally { + if (wasInterrupted) { + Thread.currentThread().interrupt() + } + } + null + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType => new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + // Geometry/Geography collect size/count without min/max bounds. Their physical value is a + // BinaryView (not Array[Byte]), so GeoColumnStats reads it via getBinaryView rather than + // BinaryColumnStats' getBinary, which would throw ClassCastException on a row that stores a + // BinaryView. They are also AtomicTypes that ColumnType (used by ObjectColumnStats) does not + // handle, so they must be matched explicitly here. + case _: GeometryType | _: GeographyType => new GeoColumnStats + // Unwrap UDTs to the same collector their underlying type would use. isSupportedByArrow + // accepts a UDT whenever its sqlType is supported (including Variant/Geometry/Geography), + // but ObjectColumnStats -> ColumnType(udt.sqlType) only unwraps one level and has no case + // for those types, so it would throw UNSUPPORTED_DATATYPE during materialization. Recursing + // here keeps the capability check and the statistics path in agreement. + case udt: UserDefinedType[_] => createColumnStats(udt.sqlType) + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxString( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + collationId: Int = StringType.collationId): (Any, Any) = { + var min: org.apache.spark.unsafe.types.UTF8String = null + var max: org.apache.spark.unsafe.types.UTF8String = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bytes = vector.asInstanceOf[org.apache.arrow.vector.VarCharVector].get(i) + val value = org.apache.spark.unsafe.types.UTF8String.fromBytes(bytes) + if (!hasValue) { + min = value.clone() + max = value.clone() + hasValue = true + } else { + if (value.semanticCompare(min, collationId) < 0) min = value.clone() + if (value.semanticCompare(max, collationId) > 0) max = value.clone() + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDecimal( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + dataType: org.apache.spark.sql.types.DataType): (Any, Any) = { + val decimalType = dataType.asInstanceOf[DecimalType] + var min: org.apache.spark.sql.types.Decimal = null + var max: org.apache.spark.sql.types.Decimal = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bigDecimal = vector.asInstanceOf[ + org.apache.arrow.vector.DecimalVector].getObject(i) + val value = org.apache.spark.sql.types.Decimal( + bigDecimal, decimalType.precision, decimalType.scale) + + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value.compareTo(min) < 0) min = value + if (value.compareTo(max) > 0) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxYearMonthInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntervalYearVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDayTimeInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = org.apache.arrow.vector.DurationVector.get( + vector.asInstanceOf[org.apache.arrow.vector.DurationVector].getDataBuffer, i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTime( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TimeNanoVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } +} + +/** + * Iterator that converts InternalRow to ArrowCachedBatch. + */ +private class InternalRowToArrowCachedBatchIterator( + rowIter: Iterator[InternalRow], + schema: Seq[Attribute], + sparkSchema: StructType, + maxRecordsPerBatch: Long, + maxBytesPerBatch: Long, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"InternalRowToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + private val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Create statistics collectors for each column + private val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + // Computed once: only CalendarInterval columns can overflow when written to Arrow nanoseconds. + private val hasCalendarInterval = ArrowCachedBatchSerializer.hasCalendarInterval(schema) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + close() + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close() + false + } + + override def next(): ArrowCachedBatch = { + var rowCount = 0 + + // Reset statistics collectors for new batch + var idx = 0 + while (idx < statsCollectors.length) { + statsCollectors(idx) = ArrowCachedBatchSerializer.createColumnStats(schema(idx).dataType) + idx += 1 + } + + Utils.tryWithSafeFinally { + // Write rows to Arrow vectors and collect statistics incrementally, stopping when either the + // record-count or byte limit is reached (whichever is hit first), so wide rows cannot form + // multi-gigabyte batches that exhaust memory or overflow Arrow's 32-bit variable-width + // offsets. A nonpositive limit means that limit is unlimited; the `<= 0` guards also keep the + // loop from emitting empty batches forever. At least one row is always written so a single + // oversized row still makes progress. The byte limit is measured from the actual bytes + // already written to the Arrow vectors (arrowWriter.sizeInBytes), which is accurate for every + // row type -- a row-size estimate would undercount large values in a GenericInternalRow (e.g. + // a multi-megabyte string) and let the batch grow past the limit. + def recordLimitReached: Boolean = maxRecordsPerBatch > 0 && rowCount >= maxRecordsPerBatch + def byteLimitReached: Boolean = + maxBytesPerBatch > 0 && arrowWriter.sizeInBytes() >= maxBytesPerBatch + ArrowCachedBatchSerializer.withIntervalOverflowTranslation(hasCalendarInterval) { + while (rowIter.hasNext && (rowCount == 0 || (!recordLimitReached && !byteLimitReached))) { + val row = rowIter.next() + arrowWriter.write(row) + + // Collect statistics for this row + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + + rowCount += 1 + } + arrowWriter.finish() + } + + // Get the Arrow RecordBatch with compression + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + // Serialize to Arrow IPC format + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + + // Build statistics InternalRow from collected stats + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + } + } + + private def close(): Unit = { + root.close() + allocator.close() + } +} + +/** + * Iterator that converts ColumnarBatch to ArrowCachedBatch. + */ +private class ColumnarBatchToArrowCachedBatchIterator( + batchIter: Iterator[ColumnarBatch], + schema: Seq[Attribute], + sparkSchema: StructType, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ColumnarBatchToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + + // Computed once: only CalendarInterval columns can overflow when written to Arrow nanoseconds. + private val hasCalendarInterval = ArrowCachedBatchSerializer.hasCalendarInterval(schema) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ArrowCachedBatch = { + val batch = batchIter.next() + // Release the consumed input batch once converted. This iterator replaces the normal + // ColumnarToRow consumer (which calls closeIfFreeable() after each batch), so without this + // an Arrow-backed source's fresh off-heap vectors would stay live for every cached batch and + // grow executor memory until OOM. Both conversion branches finish synchronously and copy the + // data out (VectorUnloader / row conversion), so the input is safe to free here on success or + // failure; closeIfFreeable() is a no-op for reusable writable/constant vectors. + // One input ColumnarBatch maps to one cached batch: the upstream batch's row count is already + // bounded by the source's batch-size config (e.g. spark.sql.parquet.columnarReaderBatchSize), + // so no further record/byte splitting is needed here. + Utils.tryWithSafeFinally { + val rowCount = batch.numRows() + + // Check if batch is already Arrow-based for zero-copy path. The zero-copy path reuses the + // input vectors but serializes them under a schema built with largeVarTypes=false, and the + // read path reconstructs that same non-large schema. Large var-width vectors use 64-bit + // offsets, so reading them back under a 32-bit-offset schema would silently corrupt data. + // Fall back to the row-based conversion (which always produces standard var-width vectors) + // whenever any input vector is, or nests, a large var-width vector. + val vectors = (0 until batch.numCols()).map(batch.column) + val zeroCopyEligible = vectors.forall { + case acv: ArrowColumnVector => + !ColumnarBatchToArrowCachedBatchIterator.containsLargeVarType(acv.getValueVector) + case _ => false + } + if (zeroCopyEligible) { + // Fast path: zero-copy extraction of Arrow RecordBatch + convertArrowBatchZeroCopy(batch, rowCount, schema, vectors) + } else { + // Slow path: convert to Arrow via rows + convertToArrowBatch(batch, rowCount, schema) + } + } { + batch.closeIfFreeable() + } + } + + private def convertArrowBatchZeroCopy( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute], + vectors: Seq[ColumnVector]): ArrowCachedBatch = { + // Zero-copy path: extract Arrow vectors directly from ArrowColumnVector + val arrowVectors = vectors.map( + _.asInstanceOf[ArrowColumnVector].getValueVector.asInstanceOf[ + org.apache.arrow.vector.FieldVector]) + + // Create a VectorSchemaRoot from the existing vectors + val root = new VectorSchemaRoot(arrowSchema, arrowVectors.asJava, rowCount) + + Utils.tryWithSafeFinally { + // Use VectorUnloader to create compressed RecordBatch + val unloader = new VectorUnloader(root, true, compressionCodec, true) + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + // Note: We don't close the root here because we don't own the vectors + // They are owned by the input ColumnarBatch + } + } + + private def convertToArrowBatch( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute]): ArrowCachedBatch = { + // Convert columnar batch to rows, then to Arrow + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + val unloader = new VectorUnloader(root, true, compressionCodec, true) + + Utils.tryWithSafeFinally { + ArrowCachedBatchSerializer.withIntervalOverflowTranslation(hasCalendarInterval) { + val rowIterator = batch.rowIterator().asScala + while (rowIterator.hasNext) { + arrowWriter.write(rowIterator.next()) + } + arrowWriter.finish() + } + + val recordBatch = unloader.getRecordBatch() + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + // Derive statistics from the built Arrow vectors rather than the input rows. The input + // rows here are ColumnarArray/ColumnarMap/ColumnarRow views (this is the non-Arrow + // ColumnarBatch path, e.g. vectorized nested Parquet/ORC), which do not expose a byte + // size; collecting row-by-row would record zero bytes for every complex value and make a + // complex-only relation report sizeInBytes=0, wrongly eligible for broadcast. Reading + // vector.getBufferSize off the finished root accounts for the actual payload, matching the + // zero-copy path. + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + root.close() + } + } +} + +private object ColumnarBatchToArrowCachedBatchIterator { + import org.apache.arrow.vector.{FieldVector, LargeVarBinaryVector, LargeVarCharVector} + + /** + * Whether the vector is, or nests, a large var-width vector (64-bit offsets). These are not + * eligible for the zero-copy path because that path serializes and reloads under a schema built + * with largeVarTypes=false; reinterpreting 64-bit offset buffers as 32-bit would corrupt data. + */ + def containsLargeVarType(vector: org.apache.arrow.vector.ValueVector): Boolean = vector match { + case _: LargeVarCharVector | _: LargeVarBinaryVector => true + case fv: FieldVector => + fv.getChildrenFromFields.asScala.exists(containsLargeVarType) + case _ => false + } +} + +/** + * Iterator that converts ArrowCachedBatch to ColumnarBatch. + */ +private class ArrowCachedBatchToColumnarBatchIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String, + prefetchEnabled: Boolean = false) extends Iterator[ColumnarBatch] { + + import java.util.concurrent.{Callable, ExecutionException, Executors, ExecutorService, Future} + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToColumnarBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + + // Track only the previous root to close it when next batch is produced + private var previousRoot: VectorSchemaRoot = null + + // Prefetch support: deserialize the next batch into its own root in a background thread while + // the current batch is being consumed. Only the deserialization (IPC read + decompression + + // loading into a fresh root) happens off-thread; closing the previous root stays on the + // consumer thread in next(), so the vectors backing a returned ColumnarBatch are never released + // while the consumer may still read them. + private val prefetchExecutor: ExecutorService = if (prefetchEnabled) { + Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-prefetch") + t.setDaemon(true) + t + }) + } else { + null + } + private var prefetchFuture: Future[VectorSchemaRoot] = _ + + // Register cleanup - close remaining root and allocator when task completes + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + // Stop the worker and close any root it already produced before closing the allocator. + // A short-circuiting consumer (e.g. LIMIT) can trigger task completion while a prefetched + // root is in flight; simply cancelling the future would drop that root and allocator.close() + // would then fail with "Memory was leaked by query". + prefetchFuture = ArrowCachedBatchSerializer.drainAndClosePrefetch( + prefetchExecutor, prefetchFuture) + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + allocator.close() + } + } + + override def hasNext: Boolean = prefetchFuture != null || batchIter.hasNext + + override def next(): ColumnarBatch = { + // Close the previous root since the consumer has moved on from the batch it backed. + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + + val root = if (prefetchFuture != null) { + val r = try { + prefetchFuture.get() + } catch { + case e: ExecutionException => throw e.getCause + } + prefetchFuture = null + r + } else { + deserializeToRoot(batchIter.next().asInstanceOf[ArrowCachedBatch]) + } + + previousRoot = root + + // Wrap vectors in ArrowColumnVector and project to selected columns. + val allColumns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + val selectedColumns = columnIndices.map(allColumns(_)) + val batch = new ColumnarBatch(selectedColumns, root.getRowCount) + + // Start prefetching the next batch while this one is being consumed. + submitPrefetch() + + batch + } + + /** Deserialize a cached batch into its own freshly-created root. Does not touch other roots. */ + private def deserializeToRoot(cachedBatch: ArrowCachedBatch): VectorSchemaRoot = { + val in = new ByteArrayInputStream(cachedBatch.arrowData) + val readChannel = new ReadChannel(Channels.newChannel(in)) + val recordBatch = MessageSerializer.deserializeRecordBatch(readChannel, allocator) + Utils.tryWithSafeFinally { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + // VectorLoader.load fills vectors incrementally, so a failure (malformed data, decompression + // error, OOM) can occur after earlier vectors have allocated buffers. Close the partially + // loaded root on failure, otherwise it becomes unreachable and the later allocator.close() + // fails with a leak error that masks the original exception. + try { + val loader = new VectorLoader(root) + loader.load(recordBatch) + root + } catch { + case t: Throwable => + root.close() + throw t + } + } { + recordBatch.close() + } + } + + /** Submit deserialization of the next batch to the background thread, if prefetch is enabled. */ + private def submitPrefetch(): Unit = { + if (prefetchEnabled && batchIter.hasNext) { + val nextCachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + prefetchFuture = prefetchExecutor.submit(new Callable[VectorSchemaRoot] { + override def call(): VectorSchemaRoot = deserializeToRoot(nextCachedBatch) + }) + } + } +} + +/** + * A typed column reader that reads from an Arrow FieldVector and writes directly + * to an UnsafeRowWriter, avoiding per-row pattern matching overhead. + */ +private abstract class ArrowColumnReader { + def vector: org.apache.arrow.vector.FieldVector + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit + def setVector(v: org.apache.arrow.vector.FieldVector): Unit +} + +private object ArrowColumnReader { + import org.apache.arrow.vector._ + + def create(dataType: DataType): ArrowColumnReader = dataType match { + case BooleanType => new ArrowColumnReader { + private var _vector: BitVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[BitVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex) != 0) + } + case ByteType => new ArrowColumnReader { + private var _vector: TinyIntVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[TinyIntVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case ShortType => new ArrowColumnReader { + private var _vector: SmallIntVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[SmallIntVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case IntegerType | DateType | _: YearMonthIntervalType => new ArrowColumnReader { + private var _vector: FieldVector = _ + // Pre-bind accessor at setVector time to avoid per-row pattern match + private var _accessor: Int => Int = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v + _accessor = v match { + case iv: IntVector => iv.get + case dv: DateDayVector => dv.get + case iv: org.apache.arrow.vector.IntervalYearVector => iv.get + } + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _accessor(rowIndex)) + } + case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => + new ArrowColumnReader { + private var _vector: FieldVector = _ + private var _accessor: Int => Long = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v + _accessor = v match { + case bv: BigIntVector => bv.get(_) + case tv: TimeStampMicroTZVector => tv.get(_) + case tv: TimeStampMicroVector => tv.get(_) + case dv: org.apache.arrow.vector.DurationVector => + i => org.apache.arrow.vector.DurationVector.get(dv.getDataBuffer, i) + case tv: org.apache.arrow.vector.TimeNanoVector => tv.get(_) + } + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _accessor(rowIndex)) + } + case FloatType => new ArrowColumnReader { + private var _vector: Float4Vector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[Float4Vector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case DoubleType => new ArrowColumnReader { + private var _vector: Float8Vector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[Float8Vector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case _: StringType => new ArrowColumnReader { + private var _vector: VarCharVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[VarCharVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val bytes = _vector.get(rowIndex) + writer.write(ordinal, UTF8String.fromBytes(bytes)) + } + } + case BinaryType => new ArrowColumnReader { + private var _vector: VarBinaryVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[VarBinaryVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + // Fast path for compact decimals (precision <= 18): + // Read the unscaled long directly from the Arrow buffer, zero allocation. + // Arrow stores Decimal as 128-bit little-endian integer in 16 bytes. + // For compact decimals, the value fits in the lower 8 bytes. + new ArrowColumnReader { + private var _vector: DecimalVector = _ + private var _dataBuffer: org.apache.arrow.memory.ArrowBuf = _ + private val typeWidth = DecimalVector.TYPE_WIDTH // 16 bytes + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v.asInstanceOf[DecimalVector] + _dataBuffer = _vector.getDataBuffer + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val startIndex = rowIndex.toLong * typeWidth + val unscaledLong = _dataBuffer.getLong(startIndex) + writer.write(ordinal, unscaledLong) + } + } + case dt: DecimalType => new ArrowColumnReader { + // Slow path for wide decimals (precision > 18): must go through BigDecimal + private var _vector: DecimalVector = _ + private val precision = dt.precision + private val scale = dt.scale + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[DecimalVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val decimal = Decimal(_vector.getObject(rowIndex), precision, scale) + writer.write(ordinal, decimal, precision, scale) + } + } + case _ => + throw new UnsupportedOperationException( + s"Complex type $dataType is handled by the fallback path") + } +} + +/** + * Fast-path iterator that converts ArrowCachedBatch to InternalRow. + * Uses pre-built typed column readers to avoid per-row pattern matching, + * and writes directly to UnsafeRowWriter to avoid intermediate SpecificInternalRow. + * Only used for schemas without complex types (Array/Struct/Map). + */ +private class ArrowCachedBatchToInternalRowIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String, + prefetchEnabled: Boolean = false) extends Iterator[InternalRow] { + + import java.util.concurrent.{Callable, ExecutionException, Future, Executors, + ExecutorService} + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToInternalRowIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private var currentRoot: VectorSchemaRoot = null + private var currentRowIndex: Int = 0 + private var currentRowCount: Int = 0 + + private val numFields = selectedSchema.length + private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + + // Pre-build typed readers per column at init time -- no per-row pattern match + private val columnReaders: Array[ArrowColumnReader] = + selectedSchema.fields.map(f => ArrowColumnReader.create(f.dataType)) + + // Write directly to UnsafeRow -- no intermediate SpecificInternalRow + UnsafeProjection + private val rowWriter = new UnsafeRowWriter(numFields) + + // Prefetch support: deserialize the next batch in background while current batch is consumed + private val prefetchExecutor: ExecutorService = if (prefetchEnabled) { + Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-row-prefetch") + t.setDaemon(true) + t + }) + } else { + null + } + private var prefetchFuture: Future[VectorSchemaRoot] = _ + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + // Stop the worker and close any root it already produced before closing the allocator; + // otherwise a prefetched root produced after a short-circuiting consumer (e.g. LIMIT) stops + // reading would leak and allocator.close() would fail with "Memory was leaked by query". + prefetchFuture = ArrowCachedBatchSerializer.drainAndClosePrefetch( + prefetchExecutor, prefetchFuture) + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + allocator.close() + } + } + + override def hasNext: Boolean = { + // Keep loading batches until the current one has rows or the input is exhausted. A cached + // batch can legitimately have zero rows (e.g. an empty ColumnarBatch from a columnar source); + // without this loop an empty batch would make hasNext return false and silently drop all + // remaining, non-empty batches. + while (currentRowIndex >= currentRowCount && (prefetchFuture != null || batchIter.hasNext)) { + loadNextBatch() + } + if (currentRowIndex < currentRowCount) { + true + } else { + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + false + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("No more rows") + } + + rowWriter.reset() + rowWriter.zeroOutNullBytes() + + val rowIdx = currentRowIndex + var i = 0 + while (i < numFields) { + val reader = columnReaders(i) + if (reader.vector.isNull(rowIdx)) { + rowWriter.setNullAt(i) + } else { + reader.read(rowIdx, i, rowWriter) + } + i += 1 + } + + currentRowIndex += 1 + rowWriter.getRow() + } + + /** Deserialize a cached batch into a VectorSchemaRoot. */ + private def deserializeBatch(cachedBatch: ArrowCachedBatch): VectorSchemaRoot = { + val in = new ByteArrayInputStream(cachedBatch.arrowData) + val readChannel = new ReadChannel(Channels.newChannel(in)) + val recordBatch = MessageSerializer.deserializeRecordBatch(readChannel, allocator) + try { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + // VectorLoader.load fills vectors incrementally, so a failure (malformed data, decompression + // error, OOM) can occur after earlier vectors have allocated buffers. Close the partially + // loaded root on failure, otherwise it becomes unreachable and the later allocator.close() + // fails with a leak error that masks the original exception. + try { + val loader = new VectorLoader(root) + loader.load(recordBatch) + root + } catch { + case t: Throwable => + root.close() + throw t + } + } finally { + recordBatch.close() + } + } + + /** Submit prefetch for the next batch if available. */ + private def submitPrefetch(): Unit = { + if (prefetchEnabled && batchIter.hasNext) { + val nextCachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + prefetchFuture = prefetchExecutor.submit(new Callable[VectorSchemaRoot] { + override def call(): VectorSchemaRoot = deserializeBatch(nextCachedBatch) + }) + } + } + + private def loadNextBatch(): Unit = { + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + + val root = if (prefetchFuture != null) { + // Use the prefetched result + val r = try { + prefetchFuture.get() + } catch { + case e: ExecutionException => throw e.getCause + } + prefetchFuture = null + r + } else { + // No prefetch available, deserialize synchronously + val cachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + deserializeBatch(cachedBatch) + } + + currentRoot = root + + // Update pre-built readers with new vectors + var i = 0 + while (i < numFields) { + columnReaders(i).setVector(root.getVector(columnIndices(i))) + i += 1 + } + + currentRowIndex = 0 + currentRowCount = root.getRowCount + + // Start prefetching the next batch while this one is being consumed + submitPrefetch() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index c09c94ff4201d..26ccd382ba219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{TimestampNanosVal, UTF8String} +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap, ColumnarRow} +import org.apache.spark.unsafe.types.{BinaryView, TimestampNanosVal, UTF8String} class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() @@ -297,6 +298,28 @@ private[columnar] final class BinaryColumnStats extends ColumnStats { Array[Any](null, null, nullCount, count, sizeInBytes) } +/** + * Size collector for Geometry/Geography columns. Their Catalyst physical value is a + * [[BinaryView]] (this is what ArrowWriter consumes via getBinaryView), so BinaryColumnStats, + * which reads row.getBinary, would throw a ClassCastException on a row that actually stores a + * BinaryView (e.g. a GenericInternalRow from a row-based reader or direct serializer use). Read + * the value through getBinaryView and use its byte length for the size; no min/max bounds. + */ +private[columnar] final class GeoColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val view: BinaryView = row.getBinaryView(ordinal) + sizeInBytes += view.numBytes() + 4 + count += 1 + } else { + gatherNullStats() + } + } + + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) +} + private[columnar] final class VariantColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { @@ -382,8 +405,28 @@ private[columnar] final class ObjectColumnStats(dataType: DataType) extends Colu override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size + // Check if this is a columnar complex type that doesn't support getSizeInBytes + val isColumnarComplexType = columnType match { + case _: ARRAY => + row.getArray(ordinal).isInstanceOf[ColumnarArray] + case _: MAP => + row.getMap(ordinal).isInstanceOf[ColumnarMap] + case struct: STRUCT => + row.getStruct(ordinal, struct.dataType.fields.length).isInstanceOf[ColumnarRow] + case _ => + false + } + + if (!isColumnarComplexType) { + // Normal path: calculate size for unsafe types + // (UnsafeArrayData/UnsafeMapData/UnsafeRow) + val size = columnType.actualSize(row, ordinal) + sizeInBytes += size + } + // else: Skip size calculation for columnar complex types + // (ColumnarArray/ColumnarMap/ColumnarRow). These are views into ColumnVectors + // and don't expose getSizeInBytes() + count += 1 } else { gatherNullStats() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index f79742907779f..e79fad6c249d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -410,7 +410,7 @@ object InMemoryRelation { } /* Visible for testing */ - private[columnar] def clearSerializer(): Unit = synchronized { ser = None } + private[sql] def clearSerializer(): Unit = synchronized { ser = None } def apply( storageLevel: StorageLevel, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala new file mode 100644 index 0000000000000..979811ef70e45 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala @@ -0,0 +1,810 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.internal.config.UI.UI_ENABLED +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +/** + * Benchmark to measure cache performance with Arrow format vs Default format. + * + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/Test/runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain " + * Results will be written to "benchmarks/ArrowCacheBenchmark-results.txt". + * }}} + */ +object ArrowCacheBenchmark extends SqlBasedBenchmark { + + // Do NOT access the inherited `spark` session - it uses default serializer + // Instead, create fresh sessions for each benchmark + + // Create separate sessions for each cache format since SPARK_CACHE_SERIALIZER is static + // CRITICAL: Can only have one active SparkContext at a time + private def createFreshSession(serializer: String): SparkSession = { + // Stop any existing session and clear the registry + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + + // CRITICAL: Clear the cached serializer instance in InMemoryRelation + // This singleton is stored statically and persists across sessions + org.apache.spark.sql.execution.columnar.InMemoryRelation.clearSerializer() + + SparkSession.builder() + .master("local[1]") + .appName(s"ArrowCacheBenchmark-$serializer") + .config(SQLConf.SHUFFLE_PARTITIONS.key, 1) + .config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1) + .config(UI_ENABLED.key, false) + .config(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, serializer) + .getOrCreate() + } + + private def cachePrimitiveTypes(): Unit = { + val numRows = 5000000 // 5M rows for faster benchmarking + runBenchmark("Cache primitive types") { + val benchmark = new Benchmark("Cache 5M rows with primitives", numRows, output = output) + + // Run Default cache benchmark (with compression - default) + benchmark.addCase("Default cache - write + read") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Default cache without compression + benchmark.addCase("Default cache - write + read (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache benchmark + benchmark.addCase("Arrow cache - write + read") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // NOTE: LZ4 compression benchmarks are commented out because Arrow's LZ4 implementation + // requires the optional lz4-java native library dependency. Without it, Arrow falls back + // to Apache Commons Compress pure-Java LZ4 implementation which is extremely slow + // (~50x slower than zstd). To enable fast LZ4 benchmarks, add this dependency to pom.xml: + // + // org.lz4 + // lz4-java + // 1.8.0 + // + + // // Run Arrow cache with lz4 compression benchmark + // benchmark.addCase("Arrow cache - write + read (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "lz4") + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "id * 2L as long_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + // Run Arrow cache with zstd level -1 (fastest) compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "-1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd level 1 compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd level 3 (default) compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def cacheWithFilters(): Unit = { + val numRows = 5000000 // 5M rows + // Each case times the whole cache build + materialization + filtered count, not the filter in + // isolation, and a fresh session is required per case because the cache serializer is resolved + // process-wide on first use. Both the default and Arrow serializers collect min/max bounds, so + // these numbers compare end-to-end cache+filter throughput between the two formats; they are + // not a measurement of partition pruning attributable to either format. + runBenchmark("Cache then filter") { + val benchmark = new Benchmark("Cache 5M rows, then filter", numRows, output = output) + + // Default cache filter benchmark (with compression - default) + benchmark.addCase("Default cache - filter") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Default cache filter without compression + benchmark.addCase("Default cache - filter (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter benchmark + benchmark.addCase("Arrow cache - filter") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // // Arrow cache filter with lz4 compression + // benchmark.addCase("Arrow cache - filter (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "lz4") + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() // Materialize + // df.filter("int_col > 2500000").count() + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + // Arrow cache filter with zstd level -1 (fastest) + benchmark.addCase("Arrow cache - filter (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "-1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter with zstd level 1 + benchmark.addCase("Arrow cache - filter (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter with zstd level 3 + benchmark.addCase("Arrow cache - filter (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def cacheColumnarInput(): Unit = { + val numRows = 2000000 // 2M rows + withTempPath { dir => + val path = dir.getAbsolutePath + + // Write parquet file using a temporary session + val tempSpark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + tempSpark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ).write.parquet(path) + } finally { + tempSpark.stop() + } + + runBenchmark("Cache columnar input (Parquet)") { + val benchmark = new Benchmark("Cache 2M rows from Parquet", numRows, output = output) + + benchmark.addCase("Default cache - columnar input") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Default cache - columnar input (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // benchmark.addCase("Arrow cache - columnar input (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "lz4") + // val parquet = spark.read.parquet(path) + // parquet.cache() + // parquet.write.format("noop").mode("overwrite").save() // Force read all data + // parquet.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + benchmark.addCase("Arrow cache - columnar input (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "-1") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "1") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + } + + private def recacheArrowData(): Unit = { + val numRows = 2000000 // 2M rows + runBenchmark("Re-cache Arrow cached data (zero-copy test)") { + val benchmark = new Benchmark("Re-cache 2M rows (zero-copy)", numRows, output = output) + + benchmark.addTimerCase("Default cache - cache a cached DF") { timer => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Default cache - cache a cached DF (uncompressed)") { timer => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zero-copy)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + // Drop a column to create a different logical plan + // This preserves ArrowColumnVector for remaining columns, enabling zero-copy + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // benchmark.addTimerCase("Arrow cache - cache a cached DF (lz4)") { timer => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "lz4") + // // Create and cache initial data (NOT timed) + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "id * 2L as long_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() // Materialize + + // // START TIMING: Cache the cached DataFrame again + // val df2 = df.drop("double_col") + // timer.startTiming() + // df2.cache() + // df2.write.format("noop").mode("overwrite").save() // Force read all data + // timer.stopTiming() + + // df2.unpersist(blocking = true) + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level -1)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "-1") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level 1)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "1") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level 3)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def columnPruning(): Unit = { + val numRows = 5000000 // 5M rows + runBenchmark("Cache with column pruning (select 1 of 20 columns)") { + val benchmark = new Benchmark( + "Cache 5M rows, select 1 column", numRows, output = output) + + // Run Default cache benchmark (with compression - default) + benchmark.addCase("Default cache - select 1 of 20 columns") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + // Create DataFrame with 20 columns + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() // Materialize cache + + // Select only first column and count + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Default cache without compression + benchmark.addCase("Default cache - select 1 of 20 (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache (no compression) + benchmark.addCase("Arrow cache - select 1 of 20") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level -1 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level -1)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "-1") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level 1 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level 1)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "1") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level 3 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level 3)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key, "zstd") + spark.conf.set(SQLConf.ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL.key, "3") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Arrow Cache vs Default Cache") { + cachePrimitiveTypes() + cacheWithFilters() + cacheColumnarInput() + recacheArrowData() + columnPruning() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala new file mode 100644 index 0000000000000..52c376142a43f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala @@ -0,0 +1,2301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.sql.{Date, Timestamp} +import java.time.{Duration, LocalDateTime, LocalTime, Period} + +import org.apache.arrow.vector.{ + BigIntVector, BitVector, DateDayVector, DecimalVector, + Float4Vector, Float8Vector, IntVector, LargeVarCharVector, SmallIntVector, + TimeNanoVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, + VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.types.variant.VariantBuilder +import org.apache.spark.unsafe.types.{BinaryView, CalendarInterval, VariantVal} +import org.apache.spark.util.Utils + +/** UDT whose sqlType is Arrow-supported (ArrayType(DoubleType)). */ +private class SupportedUDT extends UserDefinedType[Array[Double]] { + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + override def serialize(obj: Array[Double]): Any = obj + override def deserialize(datum: Any): Array[Double] = datum.asInstanceOf[Array[Double]] + override def userClass: Class[Array[Double]] = classOf[Array[Double]] +} + +/** UDT whose sqlType is ObjectType - not supported by Arrow. */ +private class UnsupportedUDT extends UserDefinedType[AnyRef] { + override def sqlType: DataType = ObjectType(classOf[AnyRef]) + override def serialize(obj: AnyRef): Any = obj + override def deserialize(datum: Any): AnyRef = datum.asInstanceOf[AnyRef] + override def userClass: Class[AnyRef] = classOf[AnyRef] +} + +class ArrowCachedBatchSerializerSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + override protected def sparkConf = { + super.sparkConf + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, + classOf[ArrowCachedBatchSerializer].getName) + .set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "false") + } + + // InMemoryRelation caches the serializer instance in a process-wide field that is initialized + // from spark.sql.cache.serializer only on first use. When another suite runs first in the same + // JVM, that field is already bound to DefaultCachedBatchSerializer, so reset it here to pick up + // the Arrow serializer configured above, and reset it again afterwards so we do not leak the + // Arrow serializer to later suites. + override def beforeAll(): Unit = { + super.beforeAll() + InMemoryRelation.clearSerializer() + } + + override def afterAll(): Unit = { + InMemoryRelation.clearSerializer() + super.afterAll() + } + + test("basic caching with primitive types") { + val df = Seq( + (1, 2L, 3.0f, 4.0, "hello"), + (5, 6L, 7.0f, 8.0, "world"), + (9, 10L, 11.0f, 12.0, "test") + ).toDF("a", "b", "c", "d", "e") + + df.cache() + checkAnswer(df, Seq( + Row(1, 2L, 3.0f, 4.0, "hello"), + Row(5, 6L, 7.0f, 8.0, "world"), + Row(9, 10L, 11.0f, 12.0, "test") + )) + + // Verify it was actually cached + assert(df.storageLevel.useMemory) + } + + test("caching with all primitive types") { + val df = Seq( + (true, 1.toByte, 2.toShort, 3, 4L, 5.0f, 6.0), + (false, 7.toByte, 8.toShort, 9, 10L, 11.0f, 12.0), + (true, 13.toByte, 14.toShort, 15, 16L, 17.0f, 18.0) + ).toDF("bool", "byte", "short", "int", "long", "float", "double") + + df.cache() + checkAnswer(df, Seq( + Row(true, 1.toByte, 2.toShort, 3, 4L, 5.0f, 6.0), + Row(false, 7.toByte, 8.toShort, 9, 10L, 11.0f, 12.0), + Row(true, 13.toByte, 14.toShort, 15, 16L, 17.0f, 18.0) + )) + } + + test("caching with null values") { + val df = Seq( + (Some(1), Some("a")), + (None, Some("b")), + (Some(3), None), + (None, None) + ).toDF("num", "str") + + df.cache() + checkAnswer(df, Seq( + Row(1, "a"), + Row(null, "b"), + Row(3, null), + Row(null, null) + )) + } + + test("caching with date and timestamp types") { + val date1 = Date.valueOf("2020-01-01") + val date2 = Date.valueOf("2021-06-15") + val ts1 = Timestamp.valueOf("2020-01-01 12:00:00") + val ts2 = Timestamp.valueOf("2021-06-15 15:30:45") + + val df = Seq( + (date1, ts1), + (date2, ts2) + ).toDF("date", "timestamp") + + df.cache() + checkAnswer(df, Seq( + Row(date1, ts1), + Row(date2, ts2) + )) + } + + test("caching with decimal types") { + val df = Seq( + BigDecimal("123.45"), + BigDecimal("678.90"), + BigDecimal("999.99") + ).toDF("decimal") + + df.cache() + checkAnswer(df, Seq( + Row(BigDecimal("123.45")), + Row(BigDecimal("678.90")), + Row(BigDecimal("999.99")) + )) + } + + test("caching with binary type") { + val df = Seq( + "hello".getBytes("UTF-8"), + "world".getBytes("UTF-8"), + "test".getBytes("UTF-8") + ).toDF("binary") + + df.cache() + val result = df.collect() + assert(result.length == 3) + assert(new String(result(0).getAs[Array[Byte]](0), "UTF-8") == "hello") + assert(new String(result(1).getAs[Array[Byte]](0), "UTF-8") == "world") + assert(new String(result(2).getAs[Array[Byte]](0), "UTF-8") == "test") + } + + test("caching with array type") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5, 6), + Seq(7, 8, 9) + ).toDF("array") + + df.cache() + checkAnswer(df, Seq( + Row(Seq(1, 2, 3)), + Row(Seq(4, 5, 6)), + Row(Seq(7, 8, 9)) + )) + } + + test("caching with struct type") { + val df = Seq( + (1, ("a", 10)), + (2, ("b", 20)), + (3, ("c", 30)) + ).toDF("id", "struct") + + df.cache() + checkAnswer(df, Seq( + Row(1, Row("a", 10)), + Row(2, Row("b", 20)), + Row(3, Row("c", 30)) + )) + } + + test("caching with map type") { + val df = Seq( + Map("a" -> 1, "b" -> 2), + Map("c" -> 3, "d" -> 4), + Map("e" -> 5, "f" -> 6) + ).toDF("map") + + df.cache() + checkAnswer(df, Seq( + Row(Map("a" -> 1, "b" -> 2)), + Row(Map("c" -> 3, "d" -> 4)), + Row(Map("e" -> 5, "f" -> 6)) + )) + } + + test("caching with nested complex types") { + val df = Seq( + (1, Seq(("a", Seq(1, 2)), ("b", Seq(3, 4)))), + (2, Seq(("c", Seq(5, 6)), ("d", Seq(7, 8)))) + ).toDF("id", "nested") + + df.cache() + checkAnswer(df, Seq( + Row(1, Seq(Row("a", Seq(1, 2)), Row("b", Seq(3, 4)))), + Row(2, Seq(Row("c", Seq(5, 6)), Row("d", Seq(7, 8)))) + )) + } + + test("caching with filter pushdown") { + val df = (1 to 100).map(i => (i, i * 2, s"str$i")).toDF("a", "b", "c") + df.cache() + + // This should use cached data with filter + val filtered = df.filter($"a" > 50) + checkAnswer(filtered, (51 to 100).map(i => Row(i, i * 2, s"str$i"))) + + // Verify cache was used + assert(filtered.queryExecution.executedPlan.toString.contains("InMemoryTableScan")) + } + + test("caching with column projection") { + val df = (1 to 100).map(i => (i, i * 2, i * 3, s"str$i")).toDF("a", "b", "c", "d") + df.cache() + + // Select subset of columns + val projected = df.select("a", "c") + checkAnswer(projected, (1 to 100).map(i => Row(i, i * 3))) + + // Verify cache was used + assert(projected.queryExecution.executedPlan.toString.contains("InMemoryTableScan")) + } + + test("caching with multiple batches") { + withSQLConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> "10") { + val df = (1 to 50).map(i => (i, s"str$i")).toDF("a", "b") + df.cache() + + checkAnswer(df, (1 to 50).map(i => Row(i, s"str$i"))) + + // Verify multiple batches were created + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + } + } + + test("uncache and recache") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + + // Cache + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + assert(df.storageLevel.useMemory) + + // Uncache + df.unpersist() + assert(!df.storageLevel.useMemory) + + // Recache + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + assert(df.storageLevel.useMemory) + } + + test("cache with aggregation") { + val df = Seq( + ("a", 1), + ("b", 2), + ("a", 3), + ("b", 4), + ("a", 5) + ).toDF("key", "value") + + df.cache() + + val agg = df.groupBy("key").sum("value") + checkAnswer(agg, Seq(Row("a", 9), Row("b", 6))) + } + + test("cache with join") { + val df1 = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value1") + val df2 = Seq((1, "x"), (2, "y"), (3, "z")).toDF("id", "value2") + + df1.cache() + df2.cache() + + val joined = df1.join(df2, "id") + checkAnswer(joined, Seq( + Row(1, "a", "x"), + Row(2, "b", "y"), + Row(3, "c", "z") + )) + } + + test("vectorized reader enabled") { + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + + // Verify vectorized reader is used + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + assert(inMemoryScan.get.supportsColumnar) + } + } + + test("compression codec - none") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "none") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("compression codec - zstd") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "zstd") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("compression codec - lz4") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "lz4") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("large dataset") { + val df = (1 to 10000).map(i => (i, i * 2, s"string$i")).toDF("a", "b", "c") + df.cache() + + checkAnswer( + df.filter($"a" > 9000), + (9001 to 10000).map(i => Row(i, i * 2, s"string$i")) + ) + } + + test("empty dataset") { + val df = Seq.empty[(Int, String)].toDF("id", "value") + df.cache() + checkAnswer(df, Seq.empty[Row]) + } + + test("single row") { + val df = Seq((1, "single")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "single"))) + } + + test("cache table command") { + withTempView("test_table") { + Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + .createOrReplaceTempView("test_table") + + sql("CACHE TABLE test_table") + + checkAnswer( + sql("SELECT * FROM test_table"), + Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")) + ) + + sql("UNCACHE TABLE test_table") + } + } + + test("columnar batch from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = (1 to 100).map(i => (i, i * 2, s"str$i")).toDF("a", "b", "c") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, (1 to 100).map(i => Row(i, i * 2, s"str$i"))) + } + } + + test("supportsColumnarInput selects the input path and does not gate type support") { + // supportsColumnarInput is a path selector, not a capability gate: it returns true regardless + // of schema, and the actual type support is enforced by checkSupportedSchema at the convert + // entry points. (There is no per-type fallback to another serializer, so gating here would + // not help anyway.) + val serializer = new ArrowCachedBatchSerializer() + assert(serializer.supportsColumnarInput(Seq(AttributeReference("int", IntegerType)()))) + assert(serializer.supportsColumnarInput( + Seq(AttributeReference("obj", ObjectType(classOf[AnyRef]))()))) + } + + test("checkSupportedSchema accepts supported types and rejects unsupported ones") { + // Supported schemas (primitive, temporal, decimal, complex, nested) must not throw. + val supported = Seq( + Seq(AttributeReference("bool", BooleanType)(), AttributeReference("long", LongType)(), + AttributeReference("string", StringType)(), AttributeReference("binary", BinaryType)()), + Seq(AttributeReference("date", DateType)(), AttributeReference("ts", TimestampType)(), + AttributeReference("tsNtz", TimestampNTZType)()), + Seq(AttributeReference("decimal", DecimalType(10, 2))()), + Seq(AttributeReference("array", ArrayType(IntegerType))(), + AttributeReference("map", MapType(StringType, IntegerType))()), + Seq(AttributeReference("nested", ArrayType(StructType(Seq( + StructField("x", IntegerType), StructField("y", ArrayType(StringType))))))()) + ) + supported.foreach(ArrowCachedBatchSerializer.checkSupportedSchema) + + // An unsupported type (ObjectType, also via an unsupported-sqlType UDT) must throw a clear + // error rather than failing deeper in conversion. + val e = intercept[SparkException] { + ArrowCachedBatchSerializer.checkSupportedSchema( + Seq(AttributeReference("obj", ObjectType(classOf[AnyRef]))())) + } + assert(e.getMessage.contains("Arrow cache does not support")) + intercept[SparkException] { + ArrowCachedBatchSerializer.checkSupportedSchema( + Seq(AttributeReference("udt", new UnsupportedUDT())())) + } + } + + test("isSupportedByArrow correctly validates all types") { + // Verify that isSupportedByArrow handles all standard Spark SQL types + assert(ArrowUtils.isSupportedByArrow(BooleanType)) + assert(ArrowUtils.isSupportedByArrow(ByteType)) + assert(ArrowUtils.isSupportedByArrow(ShortType)) + assert(ArrowUtils.isSupportedByArrow(IntegerType)) + assert(ArrowUtils.isSupportedByArrow(LongType)) + assert(ArrowUtils.isSupportedByArrow(FloatType)) + assert(ArrowUtils.isSupportedByArrow(DoubleType)) + assert(ArrowUtils.isSupportedByArrow(StringType)) + assert(ArrowUtils.isSupportedByArrow(BinaryType)) + assert(ArrowUtils.isSupportedByArrow(DateType)) + assert(ArrowUtils.isSupportedByArrow(TimestampType)) + assert(ArrowUtils.isSupportedByArrow(TimestampNTZType)) + assert(ArrowUtils.isSupportedByArrow(DecimalType(10, 2))) + assert(ArrowUtils.isSupportedByArrow(NullType)) + assert(ArrowUtils.isSupportedByArrow(CalendarIntervalType)) + + // Complex types + assert(ArrowUtils.isSupportedByArrow(ArrayType(IntegerType))) + assert(ArrowUtils.isSupportedByArrow(StructType(Seq(StructField("x", IntegerType))))) + assert(ArrowUtils.isSupportedByArrow(MapType(StringType, IntegerType))) + + // Nested complex types + assert(ArrowUtils.isSupportedByArrow( + ArrayType(StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StringType)) + ))) + )) + + // UDT: delegates to sqlType - supported when sqlType is Arrow-compatible + // ExamplePointUDT.sqlType = ArrayType(DoubleType) -> supported + assert(ArrowUtils.isSupportedByArrow(new ExamplePointUDT()), + "UDT with Arrow-supported sqlType should be supported") + assert(ArrowUtils.isSupportedByArrow(new SupportedUDT()), + "UDT with ArrayType(DoubleType) sqlType should be supported") + // UDT with ObjectType sqlType -> not supported (ObjectType is internal, not an Arrow type) + assert(!ArrowUtils.isSupportedByArrow(new UnsupportedUDT()), + "UDT with ObjectType sqlType should not be supported") + } + + test("verify Arrow cache serializer is actually used") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + df.count() // Materialize the cache + + // Verify the query plan uses InMemoryTableScan + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined, "InMemoryTableScan should be present in cached query plan") + + // Verify the serializer is ArrowCachedBatchSerializer + val serializer = inMemoryScan.get.relation.cacheBuilder.serializer + assert(serializer.isInstanceOf[ArrowCachedBatchSerializer], + s"Expected ArrowCachedBatchSerializer but got ${serializer.getClass.getName}") + } + + test("columnar input with array type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3)), + (2, Seq(4, 5, 6)), + (3, Seq(7, 8, 9)) + ).toDF("id", "array_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3)), + Row(2, Seq(4, 5, 6)), + Row(3, Seq(7, 8, 9)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with struct type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, ("a", 10)), + (2, ("b", 20)), + (3, ("c", 30)) + ).toDF("id", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Row("a", 10)), + Row(2, Row("b", 20)), + Row(3, Row("c", 30)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with map type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Map("a" -> 1, "b" -> 2)), + (2, Map("c" -> 3, "d" -> 4)), + (3, Map("e" -> 5, "f" -> 6)) + ).toDF("id", "map_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Map("a" -> 1, "b" -> 2)), + Row(2, Map("c" -> 3, "d" -> 4)), + Row(3, Map("e" -> 5, "f" -> 6)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with nested complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(("a", Seq(1, 2)), ("b", Seq(3, 4)))), + (2, Seq(("c", Seq(5, 6)), ("d", Seq(7, 8)))) + ).toDF("id", "nested_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(Row("a", Seq(1, 2)), Row("b", Seq(3, 4)))), + Row(2, Seq(Row("c", Seq(5, 6)), Row("d", Seq(7, 8)))) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with array of structs from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(("apple", 1.5), ("banana", 2.0))), + (2, Seq(("orange", 1.8), ("grape", 3.5))), + (3, Seq(("mango", 2.5))) + ).toDF("id", "items") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(Row("apple", 1.5), Row("banana", 2.0))), + Row(2, Seq(Row("orange", 1.8), Row("grape", 3.5))), + Row(3, Seq(Row("mango", 2.5))) + )) + + // Verify cache was used and operations work + val filtered = cached.filter($"id" > 1) + checkAnswer(filtered, Seq( + Row(2, Seq(Row("orange", 1.8), Row("grape", 3.5))), + Row(3, Seq(Row("mango", 2.5))) + )) + } + } + + test("columnar input with struct containing arrays from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, ("user1", Seq("tag1", "tag2", "tag3"))), + (2, ("user2", Seq("tag4", "tag5"))), + (3, ("user3", Seq("tag6"))) + ).toDF("id", "user_info") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Row("user1", Seq("tag1", "tag2", "tag3"))), + Row(2, Row("user2", Seq("tag4", "tag5"))), + Row(3, Row("user3", Seq("tag6"))) + )) + + // Verify we can access nested fields + val extracted = cached.select($"id", $"user_info._1".as("name")) + checkAnswer(extracted, Seq( + Row(1, "user1"), + Row(2, "user2"), + Row(3, "user3") + )) + } + } + + test("columnar input with map of arrays from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Map("a" -> Seq(1, 2, 3), "b" -> Seq(4, 5))), + (2, Map("c" -> Seq(6, 7), "d" -> Seq(8, 9, 10))) + ).toDF("id", "map_of_arrays") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Map("a" -> Seq(1, 2, 3), "b" -> Seq(4, 5))), + Row(2, Map("c" -> Seq(6, 7), "d" -> Seq(8, 9, 10))) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with null values in complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Some(Seq(1, 2, 3)), Some(("a", 10))), + (2, None, Some(("b", 20))), + (3, Some(Seq(4, 5)), None), + (4, None, None) + ).toDF("id", "array_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(2, null, Row("b", 20)), + Row(3, Seq(4, 5), null), + Row(4, null, null) + )) + + // Verify filtering works with nulls + val filtered = cached.filter($"array_col".isNotNull) + checkAnswer(filtered, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(3, Seq(4, 5), null) + )) + } + } + + test("columnar input with empty arrays and maps from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3), Map("a" -> 1)), + (2, Seq.empty[Int], Map.empty[String, Int]), + (3, Seq(4), Map("b" -> 2, "c" -> 3)) + ).toDF("id", "array_col", "map_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Map("a" -> 1)), + Row(2, Seq.empty[Int], Map.empty[String, Int]), + Row(3, Seq(4), Map("b" -> 2, "c" -> 3)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with deeply nested structures from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + // Create a deeply nested structure: Array[Struct[Map[String, Array[Int]]]] + val df = Seq( + (1, Seq( + (Map("x" -> Seq(1, 2)), "data1"), + (Map("y" -> Seq(3, 4, 5)), "data2") + )), + (2, Seq( + (Map("z" -> Seq(6)), "data3") + )) + ).toDF("id", "deep_nested") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq( + Row(Map("x" -> Seq(1, 2)), "data1"), + Row(Map("y" -> Seq(3, 4, 5)), "data2") + )), + Row(2, Seq( + Row(Map("z" -> Seq(6)), "data3") + )) + )) + + // Verify operations work on deeply nested data + val result = cached.filter($"id" === 1) + assert(result.count() === 1) + } + } + + test("columnar input with mixed primitive and complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, "name1", 100L, Seq(1, 2, 3), Map("k1" -> "v1"), ("nested", 99)), + (2, "name2", 200L, Seq(4, 5), Map("k2" -> "v2"), ("nested2", 88)), + (3, "name3", 300L, Seq(6), Map("k3" -> "v3"), ("nested3", 77)) + ).toDF("id", "name", "value", "array_col", "map_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, "name1", 100L, Seq(1, 2, 3), Map("k1" -> "v1"), Row("nested", 99)), + Row(2, "name2", 200L, Seq(4, 5), Map("k2" -> "v2"), Row("nested2", 88)), + Row(3, "name3", 300L, Seq(6), Map("k3" -> "v3"), Row("nested3", 77)) + )) + + // Verify column projection works + val projected = cached.select("id", "array_col", "struct_col") + checkAnswer(projected, Seq( + Row(1, Seq(1, 2, 3), Row("nested", 99)), + Row(2, Seq(4, 5), Row("nested2", 88)), + Row(3, Seq(6), Row("nested3", 77)) + )) + } + } + + test("columnar input with large complex types dataset from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + // Create a larger dataset with complex types + val df = (1 to 1000).map { i => + (i, Seq(i, i * 2, i * 3), Map(s"key$i" -> i * 10), (s"struct$i", i * 100)) + }.toDF("id", "array_col", "map_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + + // Verify a filtered subset + val filtered = cached.filter($"id" > 990) + assert(filtered.count() === 10) + + // Verify content of filtered data + val result = filtered.collect().sortBy(_.getInt(0)) + assert(result.length === 10) + assert(result(0).getInt(0) === 991) + assert(result(0).getAs[Seq[Int]](1) === Seq(991, 1982, 2973)) + } + } + + test("columnar input with vectorized reader and complex types") { + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3), ("a", 10)), + (2, Seq(4, 5, 6), ("b", 20)), + (3, Seq(7, 8, 9), ("c", 30)) + ).toDF("id", "array_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache with vectorized reader enabled + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(2, Seq(4, 5, 6), Row("b", 20)), + Row(3, Seq(7, 8, 9), Row("c", 30)) + )) + + // Verify vectorized reader is used + val plan = cached.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + assert(inMemoryScan.get.supportsColumnar) + } + } + } + + test("InternalRow path (readValueFromVector) handles all supported data types") { + // Exercises every explicit type arm in readValueFromVector via the InternalRow fallback path + // (CACHE_VECTORIZED_READER_ENABLED=false, set at suite level). Each case verifies a full + // cache write -> read roundtrip including both non-null and null values. + + // --- Primitive types --- + + // BooleanType: BitVector.get(i) != 0 + val boolDf = Seq(Some(true), None, Some(false)).toDF("v") + boolDf.cache() + checkAnswer(boolDf, Seq(Row(true), Row(null), Row(false))) + boolDf.unpersist() + + // ByteType: TinyIntVector.get(i) + val byteDf = Seq(Some(1.toByte), None, Some(10.toByte)).toDF("v") + byteDf.cache() + checkAnswer(byteDf, Seq(Row(1.toByte), Row(null), Row(10.toByte))) + byteDf.unpersist() + + // ShortType: SmallIntVector.get(i) + val shortDf = Seq(Some(1.toShort), None, Some(100.toShort)).toDF("v") + shortDf.cache() + checkAnswer(shortDf, Seq(Row(1.toShort), Row(null), Row(100.toShort))) + shortDf.unpersist() + + // IntegerType: IntVector.get(i) + val intDf = Seq(Some(42), None, Some(-7)).toDF("v") + intDf.cache() + checkAnswer(intDf, Seq(Row(42), Row(null), Row(-7))) + intDf.unpersist() + + // LongType: BigIntVector.get(i) + val longDf = Seq(Some(100L), None, Some(-50L)).toDF("v") + longDf.cache() + checkAnswer(longDf, Seq(Row(100L), Row(null), Row(-50L))) + longDf.unpersist() + + // FloatType: Float4Vector.get(i) + val floatDf = Seq(Some(3.14f), None, Some(-1.0f)).toDF("v") + floatDf.cache() + checkAnswer(floatDf, Seq(Row(3.14f), Row(null), Row(-1.0f))) + floatDf.unpersist() + + // DoubleType: Float8Vector.get(i) + val doubleDf = Seq(Some(2.718), None, Some(-1.0)).toDF("v") + doubleDf.cache() + checkAnswer(doubleDf, Seq(Row(2.718), Row(null), Row(-1.0))) + doubleDf.unpersist() + + // --- String and Binary types --- + + // StringType: VarCharVector.get(i) -> UTF8String.fromBytes + val stringDf = Seq(Some("hello"), None, Some("world")).toDF("v") + stringDf.cache() + checkAnswer(stringDf, Seq(Row("hello"), Row(null), Row("world"))) + stringDf.unpersist() + + // BinaryType: VarBinaryVector.get(i) + val bytes1 = "hello".getBytes("UTF-8") + val bytes2 = "world".getBytes("UTF-8") + val binaryDf = Seq(bytes1, bytes2).toDF("v") + binaryDf.cache() + val binaryResult = binaryDf.collect() + assert(binaryResult(0).getAs[Array[Byte]](0) sameElements bytes1) + assert(binaryResult(1).getAs[Array[Byte]](0) sameElements bytes2) + binaryDf.unpersist() + + // DecimalType (compact, precision <= 18): fast path reads unscaled long from Arrow buffer + val decDf = Seq(Some(BigDecimal("123.45")), None, Some(BigDecimal("678.90"))).toDF("v") + decDf.cache() + checkAnswer(decDf, Seq(Row(BigDecimal("123.45")), Row(null), Row(BigDecimal("678.90")))) + decDf.unpersist() + + // DecimalType (compact, negative values): verifies sign-bit correctness when reading + // lower 8 bytes of Arrow's 128-bit little-endian two's-complement buffer as signed Long + val negDecData = Seq( + new java.math.BigDecimal("-123.45"), + new java.math.BigDecimal("0.00"), + new java.math.BigDecimal("-999999.99")) + val negDecDf = spark.createDataFrame( + spark.sparkContext.parallelize(negDecData.map(Row(_))), + StructType(Seq(StructField("v", DecimalType(10, 2))))) + negDecDf.cache() + checkAnswer(negDecDf, negDecData.map(d => Row(d))) + negDecDf.unpersist() + + // DecimalType (wide, precision > 18): slow path via DecimalVector.getObject -> BigDecimal + val wideDecData = Seq( + new java.math.BigDecimal("12345678901234567890.1234567890"), + new java.math.BigDecimal("-99999999999999999999.9999999999"), + new java.math.BigDecimal("0.0000000001")) + val wideDecDf = spark.createDataFrame( + spark.sparkContext.parallelize(wideDecData.map(Row(_))), + StructType(Seq(StructField("v", DecimalType(30, 10))))) + wideDecDf.cache() + checkAnswer(wideDecDf, wideDecData.map(d => Row(d))) + wideDecDf.unpersist() + + // --- Date and time types --- + + // DateType: DateDayVector.get(i) (days since epoch) + val dateDf = Seq(Some(Date.valueOf("2020-01-01")), None, Some(Date.valueOf("2025-12-31"))) + .toDF("v") + dateDf.cache() + checkAnswer(dateDf, + Seq(Row(Date.valueOf("2020-01-01")), Row(null), Row(Date.valueOf("2025-12-31")))) + dateDf.unpersist() + + // TimestampType: TimeStampMicroTZVector.get(i) (microseconds since epoch) + val ts1 = Timestamp.valueOf("2020-01-01 12:00:00") + val ts2 = Timestamp.valueOf("2025-06-15 00:00:00") + val tsDf = Seq(Some(ts1), None, Some(ts2)).toDF("v") + tsDf.cache() + checkAnswer(tsDf, Seq(Row(ts1), Row(null), Row(ts2))) + tsDf.unpersist() + + // TimestampNTZType: TimeStampMicroVector.get(i) (microseconds, no timezone) + val ldt1 = LocalDateTime.of(2020, 1, 1, 12, 0) + val ldt2 = LocalDateTime.of(2025, 6, 15, 0, 0) + val tsNtzDf = Seq(Some(ldt1), None, Some(ldt2)).toDF("v") + tsNtzDf.cache() + checkAnswer(tsNtzDf, Seq(Row(ldt1), Row(null), Row(ldt2))) + tsNtzDf.unpersist() + + // --- Interval types --- + + // YearMonthIntervalType: IntervalYearVector.get(i) (months) + val ymiSql = "SELECT INTERVAL '1-1' YEAR TO MONTH AS ymi" + val ymiDf = spark.sql(ymiSql) + ymiDf.cache() + checkAnswer(ymiDf, spark.sql(ymiSql)) + ymiDf.unpersist() + + // DayTimeIntervalType: DurationVector.get(int) returns ArrowBuf; must use static form + val dtiSql = "SELECT INTERVAL '1' DAY AS dti" + val dtiDf = spark.sql(dtiSql) + dtiDf.cache() + checkAnswer(dtiDf, spark.sql(dtiSql)) + dtiDf.unpersist() + + // TimeType: TimeNanoVector.get(i) (nanoseconds since midnight) + val timeDf = Seq(LocalTime.of(12, 30, 45), LocalTime.of(0, 0, 0)).toDF("t") + timeDf.cache() + checkAnswer(timeDf, Seq(Row(LocalTime.of(12, 30, 45)), Row(LocalTime.of(0, 0, 0)))) + timeDf.unpersist() + + // CalendarIntervalType: ArrowColumnVector.getInterval(i) (IntervalMonthDayNanoVector) + val interval = new CalendarInterval(1, 2, 3000000L) // 1 month, 2 days, 3 ms + val ciSchema = StructType(Seq(StructField("ci", CalendarIntervalType, nullable = true))) + val ciDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(interval), Row(null))), ciSchema) + ciDf.cache() + checkAnswer(ciDf, Seq(Row(interval), Row(null))) + ciDf.unpersist() + + // --- Null type --- + + // NullType: row.setNullAt without dispatching into readValueFromVector + val nullSchema = StructType(Seq(StructField("n", NullType, nullable = true))) + val nullDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(null), Row(null))), nullSchema) + nullDf.cache() + checkAnswer(nullDf, Seq(Row(null), Row(null))) + nullDf.unpersist() + + // --- Complex types --- + + // ArrayType: ArrowColumnVector.getArray(i) (ListVector) + val arrayDf = Seq(Seq(1, 2, 3), Seq(4, 5, 6)).toDF("v") + arrayDf.cache() + checkAnswer(arrayDf, Seq(Row(Seq(1, 2, 3)), Row(Seq(4, 5, 6)))) + arrayDf.unpersist() + + // StructType: ArrowColumnVector.getStruct(i) (StructVector) + val structSql = + "SELECT named_struct('a', 1, 'b', 'x') AS v " + + "UNION ALL SELECT named_struct('a', 2, 'b', 'y') AS v" + val structDf = spark.sql(structSql) + structDf.cache() + checkAnswer(structDf, spark.sql(structSql)) + structDf.unpersist() + + // MapType: ArrowColumnVector.getMap(i) (MapVector) + val mapDf = Seq(Map(1 -> "a"), Map(2 -> "b")).toDF("v") + mapDf.cache() + checkAnswer(mapDf, Seq(Row(Map(1 -> "a")), Row(Map(2 -> "b")))) + mapDf.unpersist() + + // UserDefinedType: dispatches to readValueFromVector with udt.sqlType (ArrayType(DoubleType)) + val point = new ExamplePoint(1.0, 2.0) + val udtSchema = StructType(Seq(StructField("p", new ExamplePointUDT(), nullable = true))) + val udtDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(point), Row(null))), udtSchema) + udtDf.cache() + checkAnswer(udtDf, Seq(Row(point), Row(null))) + udtDf.unpersist() + + // VariantType: ArrowColumnVector.getVariant(i) (StructVector) + val variantDf = spark.sql("SELECT parse_json('{\"a\":1}') AS v") + variantDf.cache() + checkAnswer(variantDf.selectExpr("to_json(v)"), Seq(Row("{\"a\":1}"))) + variantDf.unpersist() + } + + // Helper: cache a single-column DataFrame (row path) and return its ArrowCachedBatch stats. + // Stats layout per column: [lowerBound(0), upperBound(1), nullCount(2), rowCount(3), size(4)]. + private def cachedStats(df: org.apache.spark.sql.DataFrame) + : org.apache.spark.sql.catalyst.InternalRow = { + df.count() // trigger cache population + val relation = df.queryExecution.executedPlan.collectFirst { + case scan: InMemoryTableScanExec => scan.relation + }.get + relation.cacheBuilder.cachedColumnBuffers.first().asInstanceOf[ArrowCachedBatch].stats + } + + // Helper: creates a single-column, single-partition DataFrame backed by an RDD. + // LocalRelation can split across multiple partitions, causing cachedStats to see only the first + // partition's stats. sc.parallelize(data, numSlices=1) forces exactly one partition. + private def singlePartDf(values: Seq[Any], dt: DataType): org.apache.spark.sql.DataFrame = + spark.createDataFrame( + spark.sparkContext.parallelize(values.map(v => Row(v)), 1), + StructType(Seq(StructField("v", dt, nullable = true)))) + + test("createColumnStats returns the correct ColumnStats subclass for each supported type") { + // Direct unit test: verify the stats class dispatched for each Spark type, which determines + // whether partition pruning via min/max bounds is enabled. + + // Orderable types: createColumnStats returns a stats class that tracks min/max bounds. + assert(ArrowCachedBatchSerializer.createColumnStats(BooleanType) + .isInstanceOf[BooleanColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(ByteType).isInstanceOf[ByteColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(ShortType).isInstanceOf[ShortColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(IntegerType).isInstanceOf[IntColumnStats]) + // DateType is stored as Int (days since epoch) -> IntColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(DateType).isInstanceOf[IntColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(LongType).isInstanceOf[LongColumnStats]) + // TimestampType/NTZ stored as Long (microseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(TimestampType) + .isInstanceOf[LongColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(TimestampNTZType) + .isInstanceOf[LongColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(FloatType).isInstanceOf[FloatColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(DoubleType).isInstanceOf[DoubleColumnStats]) + // StringType (all collations) -> StringColumnStats with collation-aware semantic comparison + assert(ArrowCachedBatchSerializer.createColumnStats(StringType).isInstanceOf[StringColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + new StringType(1)).isInstanceOf[StringColumnStats]) // collationId 1 = UTF8_LCASE + assert(ArrowCachedBatchSerializer.createColumnStats( + StringType("UNICODE")).isInstanceOf[StringColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(DecimalType(10, 2)) + .isInstanceOf[DecimalColumnStats]) + // YearMonthIntervalType stored as Int (months) -> IntColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats( + YearMonthIntervalType()).isInstanceOf[IntColumnStats]) + // DayTimeIntervalType stored as Long (microseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats( + DayTimeIntervalType()).isInstanceOf[LongColumnStats]) + // TimeType stored as Long (nanoseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(TimeType(6)).isInstanceOf[LongColumnStats]) + + // Non-orderable types: createColumnStats returns a stats class with null bounds. + assert(ArrowCachedBatchSerializer.createColumnStats(BinaryType).isInstanceOf[BinaryColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + CalendarIntervalType).isInstanceOf[IntervalColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(VariantType) + .isInstanceOf[VariantColumnStats]) + // Geometry/Geography use a BinaryView-aware collector (their physical value is a BinaryView). + assert(ArrowCachedBatchSerializer.createColumnStats(GeometryType(4326)) + .isInstanceOf[GeoColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(GeographyType(4326)) + .isInstanceOf[GeoColumnStats]) + + // Complex types and UDT: no natural ordering -> ObjectColumnStats (null bounds). + assert(ArrowCachedBatchSerializer.createColumnStats( + ArrayType(IntegerType)).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + StructType(Seq(StructField("a", IntegerType)))).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + MapType(StringType, IntegerType)).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + new ExamplePointUDT()).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(NullType).isInstanceOf[ObjectColumnStats]) + } + + test("GeoColumnStats reads a BinaryView value from a generic row") { + // The Catalyst physical value for Geometry/Geography is a BinaryView. A row-based reader or + // direct serializer use can hand the stats collector a GenericInternalRow holding that + // BinaryView; reading it via getBinary (as BinaryColumnStats does) would throw + // ClassCastException. GeoColumnStats must read it via getBinaryView and account for its size. + val stats = ArrowCachedBatchSerializer.createColumnStats(GeometryType(4326)) + assert(stats.isInstanceOf[GeoColumnStats]) + val view = BinaryView.fromBytes(Array[Byte](1, 2, 3, 4, 5)) + val nonNull = new GenericInternalRow(Array[Any](view)) + val nullRow = new GenericInternalRow(Array[Any](null)) + stats.gatherStats(nonNull, 0) + stats.gatherStats(nullRow, 0) + val collected = stats.collectedStatistics + // Layout: [lowerBound, upperBound, nullCount, count, sizeInBytes]; no min/max for geo. + // count is the total row count (null + non-null), matching ColumnStats.gatherNullStats. + assert(collected(0) == null && collected(1) == null, "geo has no min/max bounds") + assert(collected(2) == 1, "one null value") + assert(collected(3) == 2, "two rows total (one non-null, one null)") + assert(collected(4).asInstanceOf[Long] >= 5L, "size must account for the 5-byte payload") + } + + test("row path stats: orderable types produce correct min/max bounds") { + // DataFrames use the row path (InternalRowToArrowCachedBatchIterator), exercising + // createColumnStats + buildStatisticsFromCollectors. singlePartDf ensures all values land + // in one cached batch so cachedStats.first() sees the global min and max. + + // BooleanType: lower=false, upper=true + val boolDf = singlePartDf(Seq(false, true), BooleanType).cache() + val boolStats = cachedStats(boolDf) + assert(!boolStats.isNullAt(0) && !boolStats.isNullAt(1)) + assert(!boolStats.getBoolean(0) && boolStats.getBoolean(1)) + boolDf.unpersist() + + // ByteType: lower=1, upper=10 + val byteDf = singlePartDf(Seq(1.toByte, 10.toByte), ByteType).cache() + val byteStats = cachedStats(byteDf) + assert(!byteStats.isNullAt(0) && !byteStats.isNullAt(1)) + assert(byteStats.getByte(0) == 1.toByte && byteStats.getByte(1) == 10.toByte) + byteDf.unpersist() + + // ShortType: lower=1, upper=10 + val shortDf = singlePartDf(Seq(1.toShort, 10.toShort), ShortType).cache() + val shortStats = cachedStats(shortDf) + assert(!shortStats.isNullAt(0) && !shortStats.isNullAt(1)) + assert(shortStats.getShort(0) == 1.toShort && shortStats.getShort(1) == 10.toShort) + shortDf.unpersist() + + // IntegerType: lower=1, upper=10 + val intDf = singlePartDf(Seq(1, 10), IntegerType).cache() + val intStats = cachedStats(intDf) + assert(!intStats.isNullAt(0) && !intStats.isNullAt(1)) + assert(intStats.getInt(0) == 1 && intStats.getInt(1) == 10) + intDf.unpersist() + + // DateType: stored as Int (days since epoch); 2020-01-01 < 2025-01-01 + val dateDf = singlePartDf( + Seq(Date.valueOf("2020-01-01"), Date.valueOf("2025-01-01")), DateType).cache() + val dateStats = cachedStats(dateDf) + assert(!dateStats.isNullAt(0) && !dateStats.isNullAt(1)) + assert(dateStats.getInt(0) < dateStats.getInt(1)) + dateDf.unpersist() + + // LongType: lower=1L, upper=10L + val longDf = singlePartDf(Seq(1L, 10L), LongType).cache() + val longStats = cachedStats(longDf) + assert(!longStats.isNullAt(0) && !longStats.isNullAt(1)) + assert(longStats.getLong(0) == 1L && longStats.getLong(1) == 10L) + longDf.unpersist() + + // TimestampType: stored as Long (microseconds since epoch); 2020 < 2025 + val tsDf = singlePartDf( + Seq(Timestamp.valueOf("2020-01-01 00:00:00"), Timestamp.valueOf("2025-01-01 00:00:00")), + TimestampType).cache() + val tsStats = cachedStats(tsDf) + assert(!tsStats.isNullAt(0) && !tsStats.isNullAt(1)) + assert(tsStats.getLong(0) < tsStats.getLong(1)) + tsDf.unpersist() + + // TimestampNTZType: stored as Long (microseconds since epoch); 2020 < 2025 + val tsNtzDf = singlePartDf( + Seq(LocalDateTime.of(2020, 1, 1, 0, 0), LocalDateTime.of(2025, 1, 1, 0, 0)), + TimestampNTZType).cache() + val tsNtzStats = cachedStats(tsNtzDf) + assert(!tsNtzStats.isNullAt(0) && !tsNtzStats.isNullAt(1)) + assert(tsNtzStats.getLong(0) < tsNtzStats.getLong(1)) + tsNtzDf.unpersist() + + // FloatType: NaN is included but IEEE 754 comparisons with NaN are always false, + // so NaN never updates min/max; lower=1.0f, upper=10.0f + val floatDf = singlePartDf(Seq(1.0f, Float.NaN, 10.0f), FloatType).cache() + val floatStats = cachedStats(floatDf) + assert(!floatStats.isNullAt(0) && !floatStats.isNullAt(1)) + assert(floatStats.getFloat(0) == 1.0f && floatStats.getFloat(1) == 10.0f) + floatDf.unpersist() + + // DoubleType: same NaN-exclusion behavior via IEEE 754; lower=1.0, upper=10.0 + val doubleDf = singlePartDf(Seq(1.0, Double.NaN, 10.0), DoubleType).cache() + val doubleStats = cachedStats(doubleDf) + assert(!doubleStats.isNullAt(0) && !doubleStats.isNullAt(1)) + assert(doubleStats.getDouble(0) == 1.0 && doubleStats.getDouble(1) == 10.0) + doubleDf.unpersist() + + // StringType (UTF8_BINARY): "apple" < "zebra" in binary order + val stringDf = singlePartDf(Seq("apple", "zebra"), StringType).cache() + val stringStats = cachedStats(stringDf) + assert(!stringStats.isNullAt(0) && !stringStats.isNullAt(1)) + assert(stringStats.getUTF8String(0).toString == "apple") + assert(stringStats.getUTF8String(1).toString == "zebra") + stringDf.unpersist() + + // Collated StringType (UTF8_LCASE): semantic min/max uses case-insensitive comparison. + // "Apple" and "zebra": case-insensitively "apple" < "zebra", so lower="Apple", upper="zebra". + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val collatedSchema = StructType(Seq(StructField("v", collatedStringType, nullable = true))) + val collatedDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("Apple"), Row("zebra")), 1), + collatedSchema).cache() + val collatedStats = cachedStats(collatedDf) + assert(!collatedStats.isNullAt(0), "lower bound should not be null for collated StringType") + assert(!collatedStats.isNullAt(1), "upper bound should not be null for collated StringType") + assert(collatedStats.getUTF8String(0).toString == "Apple") // semantic min + assert(collatedStats.getUTF8String(1).toString == "zebra") // semantic max + collatedDf.unpersist() + + // DecimalType(10,2): lower=1.23, upper=9.87 + val decimalDf = singlePartDf( + Seq(new java.math.BigDecimal("1.23"), new java.math.BigDecimal("9.87")), + DecimalType(10, 2)).cache() + val decimalStats = cachedStats(decimalDf) + assert(!decimalStats.isNullAt(0) && !decimalStats.isNullAt(1)) + assert(decimalStats.getDecimal(0, 10, 2).compareTo(decimalStats.getDecimal(1, 10, 2)) < 0) + decimalDf.unpersist() + + // YearMonthIntervalType: stored as Int (months); Period.of(1,0,0)=12mo < Period.of(2,0,0)=24mo + val ymiDf = singlePartDf( + Seq(Period.of(1, 0, 0), Period.of(2, 0, 0)), YearMonthIntervalType()).cache() + val ymiStats = cachedStats(ymiDf) + assert(!ymiStats.isNullAt(0) && !ymiStats.isNullAt(1)) + assert(ymiStats.getInt(0) < ymiStats.getInt(1)) + ymiDf.unpersist() + + // DayTimeIntervalType: stored as Long (microseconds); 1 day < 2 days + val dtiDf = singlePartDf( + Seq(Duration.ofDays(1), Duration.ofDays(2)), DayTimeIntervalType()).cache() + val dtiStats = cachedStats(dtiDf) + assert(!dtiStats.isNullAt(0) && !dtiStats.isNullAt(1)) + assert(dtiStats.getLong(0) < dtiStats.getLong(1)) + dtiDf.unpersist() + + // TimeType: stored as Long (nanoseconds); 08:00 < 20:00 + val timeDf = singlePartDf( + Seq(LocalTime.of(8, 0, 0), LocalTime.of(20, 0, 0)), TimeType(6)).cache() + val timeStats = cachedStats(timeDf) + assert(!timeStats.isNullAt(0) && !timeStats.isNullAt(1)) + assert(timeStats.getLong(0) < timeStats.getLong(1)) + timeDf.unpersist() + } + + test("row path stats: non-orderable types produce null lower and upper bounds") { + // Verifies that types without natural ordering return null bounds so that partition pruning + // is safely disabled for them, preventing incorrect data exclusion. + def assertNullBounds(df: org.apache.spark.sql.DataFrame): Unit = { + val stats = cachedStats(df) + assert(stats.isNullAt(0), "lower bound should be null for non-orderable type") + assert(stats.isNullAt(1), "upper bound should be null for non-orderable type") + df.unpersist() + } + + // BinaryType: no natural total ordering + assertNullBounds(Seq(Array[Byte](1, 2), Array[Byte](3, 4)).toDF("v").cache()) + + // CalendarIntervalType: unordered composite (months + days + nanoseconds) + val ciSchema = StructType(Seq(StructField("v", CalendarIntervalType, nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(new CalendarInterval(1, 2, 3000000L)), + Row(new CalendarInterval(2, 0, 0L)))), + ciSchema).cache()) + + // ArrayType: no natural ordering + assertNullBounds(spark.sql( + "SELECT array(1, 2) AS v UNION ALL SELECT array(3, 4) AS v" + ).cache()) + + // StructType: no natural ordering + assertNullBounds(spark.sql( + "SELECT named_struct('i', 1, 's', 'a') AS v " + + "UNION ALL SELECT named_struct('i', 2, 's', 'b') AS v" + ).cache()) + + // MapType: no natural ordering + assertNullBounds(spark.sql( + "SELECT map(1, 'a') AS v UNION ALL SELECT map(2, 'b') AS v" + ).cache()) + + // UserDefinedType: no natural ordering + val udtSchema = StructType(Seq(StructField("v", new ExamplePointUDT(), nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)), Row(null))), + udtSchema).cache()) + + // NullType: all values are null by definition + val nullSchema = StructType(Seq(StructField("v", NullType, nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(null), Row(null))), + nullSchema).cache()) + + // VariantType: no natural ordering + assertNullBounds(spark.sql("SELECT parse_json('{\"k\":1}') AS v").cache()) + } + + test("row path stats: all-NaN Float/Double column produces inverted sentinel bounds") { + // FloatColumnStats and DoubleColumnStats initialize upper=MinValue, lower=MaxValue as + // sentinels. IEEE 754 comparisons with NaN are always false, so NaN never beats either + // sentinel. When every value is NaN, the sentinels are returned unchanged: lower=MaxValue, + // upper=MinValue (lower > upper). This differs from the Arrow path, which returns null bounds + // for all-NaN input (because calculateMinMaxFloat/Double explicitly skips NaN with !_.isNaN + // and returns (null, null) when hasValue stays false). + val floatDf = singlePartDf(Seq(Float.NaN), FloatType).cache() + val floatStats = cachedStats(floatDf) + assert(!floatStats.isNullAt(0), + "FloatType lower should not be null for all-NaN (sentinel used)") + assert(!floatStats.isNullAt(1), + "FloatType upper should not be null for all-NaN (sentinel used)") + assert(floatStats.getFloat(0) == Float.MaxValue, + s"FloatType lower expected Float.MaxValue (sentinel), got ${floatStats.getFloat(0)}") + assert(floatStats.getFloat(1) == Float.MinValue, + s"FloatType upper expected Float.MinValue (sentinel), got ${floatStats.getFloat(1)}") + floatDf.unpersist() + + val doubleDf = singlePartDf(Seq(Double.NaN), DoubleType).cache() + val doubleStats = cachedStats(doubleDf) + assert(!doubleStats.isNullAt(0), + "DoubleType lower should not be null for all-NaN (sentinel used)") + assert(!doubleStats.isNullAt(1), + "DoubleType upper should not be null for all-NaN (sentinel used)") + assert(doubleStats.getDouble(0) == Double.MaxValue, + s"DoubleType lower expected Double.MaxValue (sentinel), got ${doubleStats.getDouble(0)}") + assert(doubleStats.getDouble(1) == Double.MinValue, + s"DoubleType upper expected Double.MinValue (sentinel), got ${doubleStats.getDouble(1)}") + doubleDf.unpersist() + } + + test("collectStatistics produces correct min/max bounds for all orderable types") { + // Direct unit test of ArrowCachedBatchSerializer.collectStatistics, which is invoked whenever + // the input ColumnarBatch contains ArrowColumnVector columns (zero-copy path in + // ColumnarBatchToArrowCachedBatchIterator). Three rows [low, mid, high] ensure min/max are + // correctly identified for each type. + val serializer = new ArrowCachedBatchSerializer() + + val schema = Seq( + AttributeReference("bool_col", BooleanType)(), // BitVector + AttributeReference("byte_col", ByteType)(), // TinyIntVector + AttributeReference("short_col", ShortType)(), // SmallIntVector + AttributeReference("float_col", FloatType)(), // Float4Vector + AttributeReference("double_col", DoubleType)(), // Float8Vector + AttributeReference("date_col", DateType)(), // DateDayVector (days since epoch) + AttributeReference("ts_col", TimestampType)(), // TimeStampMicroTZVector (microseconds) + AttributeReference("ts_ntz_col", TimestampNTZType)(),// TimeStampMicroVector (microseconds) + AttributeReference("int_col", IntegerType)(), // IntVector (standalone) + AttributeReference("long_col", LongType)(), // BigIntVector (standalone) + AttributeReference("decimal_col", DecimalType(10, 2))(), // DecimalVector + AttributeReference("ymi_col", YearMonthIntervalType())(), // IntervalYearVector (months) + AttributeReference("dti_col", DayTimeIntervalType())(), // DurationVector (microseconds) + AttributeReference("time_col", TimeType(6))() // TimeNanoVector (nanoseconds) + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val boolVector = root.getVector("bool_col").asInstanceOf[BitVector] + val byteVector = root.getVector("byte_col").asInstanceOf[TinyIntVector] + val shortVector = root.getVector("short_col").asInstanceOf[SmallIntVector] + val floatVector = root.getVector("float_col").asInstanceOf[Float4Vector] + val doubleVector = root.getVector("double_col").asInstanceOf[Float8Vector] + val dateVector = root.getVector("date_col").asInstanceOf[DateDayVector] + val tsVector = root.getVector("ts_col").asInstanceOf[TimeStampMicroTZVector] + val tsNtzVector = root.getVector("ts_ntz_col").asInstanceOf[TimeStampMicroVector] + val intVector = root.getVector("int_col").asInstanceOf[IntVector] + val longVector = root.getVector("long_col").asInstanceOf[BigIntVector] + val decimalVector = root.getVector("decimal_col").asInstanceOf[DecimalVector] + val ymiVector = root.getVector("ymi_col") + .asInstanceOf[org.apache.arrow.vector.IntervalYearVector] + val dtiVector = root.getVector("dti_col") + .asInstanceOf[org.apache.arrow.vector.DurationVector] + val timeVector = root.getVector("time_col").asInstanceOf[TimeNanoVector] + + // Row 0: low values + boolVector.setSafe(0, 0) // false + byteVector.setSafe(0, 1.toByte) + shortVector.setSafe(0, 100.toShort) + floatVector.setSafe(0, 1.0f) + doubleVector.setSafe(0, 1.0) + dateVector.setSafe(0, 18262) // 2020-01-01 + tsVector.setSafe(0, 1577836800000000L) // 2020-01-01 00:00:00 UTC in microseconds + tsNtzVector.setSafe(0, 1577836800000000L) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + decimalVector.setSafe(0, new java.math.BigDecimal("1.23")) + ymiVector.setSafe(0, 12) // 1 year = 12 months + dtiVector.setSafe(0, 86400000000L) // 1 day in microseconds + timeVector.setSafe(0, 28800000000000L) // 08:00:00 in nanoseconds + + // Row 1: mid values -- Float/Double use NaN to verify NaN is excluded from min/max + boolVector.setSafe(1, 1) // true -- becomes the max + byteVector.setSafe(1, 5.toByte) + shortVector.setSafe(1, 500.toShort) + floatVector.setSafe(1, Float.NaN) // NaN: must not affect lower=1.0f or upper=10.0f + doubleVector.setSafe(1, Double.NaN) // NaN: must not affect lower=1.0 or upper=10.0 + dateVector.setSafe(1, 19000) + tsVector.setSafe(1, 1700000000000000L) + tsNtzVector.setSafe(1, 1700000000000000L) + intVector.setSafe(1, 5) + longVector.setSafe(1, 5L) + decimalVector.setSafe(1, new java.math.BigDecimal("5.55")) + ymiVector.setSafe(1, 18) // 1.5 years = 18 months + dtiVector.setSafe(1, 172800000000L) // 2 days in microseconds + timeVector.setSafe(1, 43200000000000L) // 12:00:00 in nanoseconds + + // Row 2: high values + boolVector.setSafe(2, 0) // false again (3 rows; bool max stays true from row 1) + byteVector.setSafe(2, 10.toByte) + shortVector.setSafe(2, 1000.toShort) + floatVector.setSafe(2, 10.0f) + doubleVector.setSafe(2, 10.0) + dateVector.setSafe(2, 20000) + tsVector.setSafe(2, 1800000000000000L) + tsNtzVector.setSafe(2, 1800000000000000L) + intVector.setSafe(2, 10) + longVector.setSafe(2, 10L) + decimalVector.setSafe(2, new java.math.BigDecimal("9.87")) + ymiVector.setSafe(2, 24) // 2 years = 24 months + dtiVector.setSafe(2, 259200000000L) // 3 days in microseconds + timeVector.setSafe(2, 72000000000000L) // 20:00:00 in nanoseconds + + root.setRowCount(3) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // Stats layout: [lower(0), upper(1), nullCount(2), rowCount(3), sizeInBytes(4)] per column. + // col0 BooleanType (offset 0): lower=false, upper=true + assert(!stats.getBoolean(0), s"BooleanType lower expected false, got ${stats.getBoolean(0)}") + assert(stats.getBoolean(1), s"BooleanType upper expected true, got ${stats.getBoolean(1)}") + + // col1 ByteType (offset 5): lower=1, upper=10 + assert(stats.getByte(5) == 1.toByte, s"ByteType lower=${stats.getByte(5)}") + assert(stats.getByte(6) == 10.toByte, s"ByteType upper=${stats.getByte(6)}") + + // col2 ShortType (offset 10): lower=100, upper=1000 + assert(stats.getShort(10) == 100.toShort, s"ShortType lower=${stats.getShort(10)}") + assert(stats.getShort(11) == 1000.toShort, s"ShortType upper=${stats.getShort(11)}") + + // col3 FloatType (offset 15): lower=1.0f, upper=10.0f + assert(stats.getFloat(15) == 1.0f, s"FloatType lower=${stats.getFloat(15)}") + assert(stats.getFloat(16) == 10.0f, s"FloatType upper=${stats.getFloat(16)}") + + // col4 DoubleType (offset 20): lower=1.0, upper=10.0 + assert(stats.getDouble(20) == 1.0, s"DoubleType lower=${stats.getDouble(20)}") + assert(stats.getDouble(21) == 10.0, s"DoubleType upper=${stats.getDouble(21)}") + + // col5 DateType (offset 25): lower=18262 (2020-01-01), upper=20000 + assert(stats.getInt(25) == 18262, s"DateType lower=${stats.getInt(25)}") + assert(stats.getInt(26) == 20000, s"DateType upper=${stats.getInt(26)}") + + // col6 TimestampType (offset 30): lower < upper (microseconds since epoch) + assert(stats.getLong(30) == 1577836800000000L, + s"TimestampType lower=${stats.getLong(30)}") + assert(stats.getLong(31) == 1800000000000000L, + s"TimestampType upper=${stats.getLong(31)}") + + // col7 TimestampNTZType (offset 35): lower < upper (microseconds, no timezone) + assert(stats.getLong(35) == 1577836800000000L, + s"TimestampNTZType lower=${stats.getLong(35)}") + assert(stats.getLong(36) == 1800000000000000L, + s"TimestampNTZType upper=${stats.getLong(36)}") + + // col8 IntegerType (offset 40): lower=1, upper=10 + assert(stats.getInt(40) == 1, s"IntegerType lower=${stats.getInt(40)}") + assert(stats.getInt(41) == 10, s"IntegerType upper=${stats.getInt(41)}") + + // col9 LongType (offset 45): lower=1L, upper=10L + assert(stats.getLong(45) == 1L, s"LongType lower=${stats.getLong(45)}") + assert(stats.getLong(46) == 10L, s"LongType upper=${stats.getLong(46)}") + + // col10 DecimalType(10,2) (offset 50): lower=1.23, upper=9.87 + assert(stats.getDecimal(50, 10, 2).toJavaBigDecimal.compareTo( + new java.math.BigDecimal("1.23")) == 0, + s"DecimalType lower=${stats.getDecimal(50, 10, 2)}") + assert(stats.getDecimal(51, 10, 2).toJavaBigDecimal.compareTo( + new java.math.BigDecimal("9.87")) == 0, + s"DecimalType upper=${stats.getDecimal(51, 10, 2)}") + + // col11 YearMonthIntervalType (offset 55): lower=12 months (1yr), upper=24 months (2yr) + assert(stats.getInt(55) == 12, s"YearMonthIntervalType lower=${stats.getInt(55)}") + assert(stats.getInt(56) == 24, s"YearMonthIntervalType upper=${stats.getInt(56)}") + + // col12 DayTimeIntervalType (offset 60): lower=1 day, upper=3 days (in microseconds) + assert(stats.getLong(60) == 86400000000L, + s"DayTimeIntervalType lower=${stats.getLong(60)}") + assert(stats.getLong(61) == 259200000000L, + s"DayTimeIntervalType upper=${stats.getLong(61)}") + + // col13 TimeType (offset 65): lower=08:00:00 (28800000000000ns), + // upper=20:00:00 (72000000000000ns) + assert(stats.getLong(65) == 28800000000000L, + s"TimeType lower=${stats.getLong(65)}") + assert(stats.getLong(66) == 72000000000000L, + s"TimeType upper=${stats.getLong(66)}") + + // All null counts should be 0 + (0 until 14).foreach { col => + assert(stats.getInt(col * 5 + 2) == 0, s"nullCount for col$col should be 0") + } + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + test("collectStatistics produces correct min/max bounds for StringType") { + // StringType in Arrow is stored as VarCharVector (raw UTF-8 bytes). This test covers the + // two distinct code paths in calculateMinMaxString: binary (UTF8_BINARY) and collation-aware + // semantic (collated). The collated case directly exercises the Bug 2 fix: before the fix, + // `case StringType =>` (singleton) did not match collated types so they returned null bounds. + + // UTF8_BINARY: binary-order comparison. + // {"apple", "cherry", "banana"} -> lower=apple, upper=cherry + { + val schema = Seq(AttributeReference("str_col", StringType)()) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + try { + root.allocateNew() + val strVector = root.getVector("str_col").asInstanceOf[VarCharVector] + strVector.setSafe(0, "apple".getBytes("UTF-8"), 0, 5) + strVector.setSafe(1, "cherry".getBytes("UTF-8"), 0, 6) + strVector.setSafe(2, "banana".getBytes("UTF-8"), 0, 6) + root.setRowCount(3) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + assert(!stats.isNullAt(0), "UTF8_BINARY lower bound should not be null") + assert(!stats.isNullAt(1), "UTF8_BINARY upper bound should not be null") + assert(stats.getUTF8String(0).toString == "apple", + s"UTF8_BINARY lower expected 'apple', got ${stats.getUTF8String(0)}") + assert(stats.getUTF8String(1).toString == "cherry", + s"UTF8_BINARY upper expected 'cherry', got ${stats.getUTF8String(1)}") + root.close() + } catch { + case e: Exception => root.close(); throw e + } + } + + // UTF8_LCASE (collationId=1): case-insensitive semantic comparison. + // Data: {"Apple", "banana", "Cherry"} + // Binary order: "Apple"(A=65) < "Cherry"(C=67) < "banana"(b=98) -> binary max = "banana" + // Semantic order: apple < banana < cherry -> semantic max = "Cherry" + // Asserting upper == "Cherry" (not "banana") verifies collation-aware semanticCompare is used. + { + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val schema = Seq(AttributeReference("str_col", collatedStringType)()) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + try { + root.allocateNew() + val strVector = root.getVector("str_col").asInstanceOf[VarCharVector] + strVector.setSafe(0, "Apple".getBytes("UTF-8"), 0, 5) + strVector.setSafe(1, "banana".getBytes("UTF-8"), 0, 6) + strVector.setSafe(2, "Cherry".getBytes("UTF-8"), 0, 6) + root.setRowCount(3) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + assert(!stats.isNullAt(0), "UTF8_LCASE lower bound should not be null") + assert(!stats.isNullAt(1), "UTF8_LCASE upper bound should not be null") + assert(stats.getUTF8String(0).toString == "Apple", + s"UTF8_LCASE lower expected 'Apple' (semantic min), got ${stats.getUTF8String(0)}") + // "Cherry" is the semantic max (case-insensitively: cherry > banana > apple). + // "banana" would be the binary max -- asserting "Cherry" proves semanticCompare is used. + assert(stats.getUTF8String(1).toString == "Cherry", + s"UTF8_LCASE upper expected 'Cherry' (semantic max), got ${stats.getUTF8String(1)}") + root.close() + } catch { + case e: Exception => root.close(); throw e + } + } + } + + test("zstd compression level is honored by createCompressionCodec") { + // Regression test: createCompressionCodec used to rebuild the codec through the + // single-argument factory overload, which silently dropped the level, so every configured + // level compressed at the zstd default. Compress the same batch at an ultra-fast negative + // level and at a high level and assert the high level yields a strictly smaller payload. + def compressedSize(level: Int): Int = { + val schema = Seq(AttributeReference("str_col", StringType)()) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + try { + root.allocateNew() + val strVector = root.getVector("str_col").asInstanceOf[VarCharVector] + // Compressible but non-trivial corpus: shared structure with per-row variation, so + // different zstd levels produce measurably different output sizes. + (0 until 2000).foreach { i => + val value = + s"user-$i@example.com,record-${i % 97},payload-${i * 2654435761L}".getBytes("UTF-8") + strVector.setSafe(i, value, 0, value.length) + } + root.setRowCount(2000) + val codec = ArrowCachedBatchSerializer.createCompressionCodec("zstd", level) + val recordBatch = new VectorUnloader(root, true, codec, true).getRecordBatch() + try { + ArrowCachedBatchSerializer.serializeBatch(recordBatch).length + } finally { + recordBatch.close() + } + } finally { + root.close() + } + } + + val fastSize = compressedSize(-5) + val highSize = compressedSize(19) + assert(highSize < fastSize, + s"zstd level 19 should compress smaller than level -5, " + + s"got level 19 -> $highSize bytes vs level -5 -> $fastSize bytes; " + + "equal sizes mean the configured level is being ignored") + } + + test("collectStatistics returns null bounds when all Float/Double values are NaN") { + // When every non-null value in a Float or Double column is NaN, calculateMinMaxFloat/Double + // finds no valid (non-NaN) values. hasValue stays false -> returns (null, null) -> null bounds. + // Null bounds disable partition pruning, ensuring NaN-only batches are never incorrectly + // pruned. + val schema = Seq( + AttributeReference("float_col", FloatType)(), + AttributeReference("double_col", DoubleType)() + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val floatVector = root.getVector("float_col").asInstanceOf[Float4Vector] + val doubleVector = root.getVector("double_col").asInstanceOf[Float8Vector] + + floatVector.setSafe(0, Float.NaN) + floatVector.setSafe(1, Float.NaN) + doubleVector.setSafe(0, Double.NaN) + doubleVector.setSafe(1, Double.NaN) + root.setRowCount(2) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // FloatType (col0, offset 0): no valid values -> null bounds + assert(stats.isNullAt(0), "FloatType lower bound should be null when all values are NaN") + assert(stats.isNullAt(1), "FloatType upper bound should be null when all values are NaN") + + // DoubleType (col1, offset 5): no valid values -> null bounds + assert(stats.isNullAt(5), "DoubleType lower bound should be null when all values are NaN") + assert(stats.isNullAt(6), "DoubleType upper bound should be null when all values are NaN") + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + test("collectStatistics returns null bounds for non-orderable types") { + // BinaryType has no natural ordering, so its lower and upper bounds must be null. + // Null bounds disable partition pruning for those columns, preventing incorrect data exclusion. + // A control IntegerType column confirms bounds are per-type, not per-batch. + val schema = Seq( + AttributeReference("bin_col", BinaryType)(), // VarBinaryVector -- unordered + AttributeReference("int_col", IntegerType)() // IntVector -- orderable (control column) + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val binVector = root.getVector("bin_col").asInstanceOf[VarBinaryVector] + val intVector = root.getVector("int_col").asInstanceOf[IntVector] + + binVector.setSafe(0, "hello".getBytes("UTF-8")) + binVector.setSafe(1, "world".getBytes("UTF-8")) + intVector.setSafe(0, 1) + intVector.setSafe(1, 10) + root.setRowCount(2) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // BinaryType (col0, offset 0): both bounds must be null -- no ordering defined + assert(stats.isNullAt(0), "BinaryType lower bound should be null") + assert(stats.isNullAt(1), "BinaryType upper bound should be null") + assert(stats.getInt(2) == 0, "BinaryType null count should be 0") + assert(stats.getInt(3) == 2, "BinaryType row count should be 2") + + // IntegerType (col1, offset 5): bounds should be non-null and correct + assert(!stats.isNullAt(5), "IntegerType lower bound should not be null") + assert(!stats.isNullAt(6), "IntegerType upper bound should not be null") + assert(stats.getInt(5) == 1, s"IntegerType lower=${stats.getInt(5)}") + assert(stats.getInt(6) == 10, s"IntegerType upper=${stats.getInt(6)}") + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + // ------------------------------------------------------------------------- + // Collated string bug fixes + // ------------------------------------------------------------------------- + + test("caching collated string columns does not throw UnsupportedOperationException") { + // Bug: readValueFromVector used `case StringType =>` (singleton match) which only matches + // UTF8_BINARY. Collated StringType instances (e.g. UTF8_LCASE, UNICODE) are separate class + // instances and fell through to `case other => throw UnsupportedOperationException(...)`. + // Fix: use `case _: StringType =>` to match all string type instances. + Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach { collation => + withTable("tbl") { + sql(s"CACHE TABLE tbl AS SELECT col FROM VALUES " + + s"('hello' COLLATE $collation), ('world' COLLATE $collation) AS t(col)") + checkAnswer( + sql("SELECT col FROM tbl"), + Seq(Row("hello"), Row("world"))) + } + } + } + + test("caching collated string columns with null values reads correctly") { + // Verify that null collated string values are also handled correctly in readValueFromVector. + withTable("tbl") { + sql("CACHE TABLE tbl AS SELECT col FROM VALUES " + + "('a' COLLATE UTF8_LCASE), (null), ('B' COLLATE UTF8_LCASE) AS t(col)") + checkAnswer( + sql("SELECT col FROM tbl"), + Seq(Row("a"), Row(null), Row("B"))) + } + } + + test("filter on cached collated column uses correct semantic stats for partition pruning") { + // Bug: collectStatistics used `case StringType =>` (singleton), so collated string columns + // got null min/max stats. When InMemoryTableScanExec evaluated the partition filter + // (e.g. col = 'a') against null bounds, SQL null was coerced to false and the batch was + // incorrectly pruned, causing queries to return empty results even when matching rows exist. + // Fix: use `case st: StringType =>` and pass st.collationId to calculateMinMaxString so + // stats are computed with collation-aware semanticCompare, matching + // DefaultCachedBatchSerializer. + withTable("tbl") { + // Cache the table so InMemoryTableScanExec is used with partition-filter pushdown. + sql("CACHE TABLE tbl AS SELECT col FROM VALUES " + + "('a' COLLATE UTF8_LCASE), ('B' COLLATE UTF8_LCASE), ('c' COLLATE UTF8_LCASE) AS t(col)") + + // 'a' is in the table; with null stats (before fix) the batch would be incorrectly pruned. + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Seq(Row("a"))) + // 'B' is in the table; UTF8_LCASE: 'b' == 'B', so this matches 'B'. + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'B'"), Seq(Row("B"))) + // 'z' is not in the table; result should be empty (not incorrectly pruned to empty). + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'z'"), Seq.empty) + } + } + + test("row path stats for collated strings use collation-aware semantic comparison") { + // Bug: createColumnStats used `case StringType =>` (singleton), so collated string columns + // got StringColumnStats(StringType) -- i.e., the wrong collation ID (UTF8_BINARY=0) -- instead + // of StringColumnStats(collatedType). Since StringColumnStats uses semanticCompare(collationId) + // for ordering, passing the wrong collation ID produced binary-order stats for collated + // columns, + // which could incorrectly prune batches for case-insensitive or locale-sensitive collations. + // Fix: use `case st: StringType => new StringColumnStats(st)`. + // + // Test: cache {"Apple", "banana", "Cherry"} with UTF8_LCASE. + // Binary order: "Apple" < "Cherry" < "banana" (uppercase < lowercase in ASCII). + // Semantic (case-insensitive) order: "Apple" < "banana" < "Cherry". + // So semantic lower="Apple", upper="Cherry"; binary lower="Apple", upper="banana". + // A filter WHERE col = 'cherry' should match "Cherry" semantically but not return empty. + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val schema = StructType(Seq(StructField("v", collatedStringType, nullable = true))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("Apple"), Row("banana"), Row("Cherry")), 1), + schema).cache() + val stats = cachedStats(df) + // With correct semantic stats: lower="Apple", upper="Cherry" (case-insensitive order) + // "apple" <= "Apple" <= "banana" <= "Cherry" <= "cherry" semantically. + assert(!stats.isNullAt(0), "lower bound should not be null for collated StringType") + assert(!stats.isNullAt(1), "upper bound should not be null for collated StringType") + assert(stats.getUTF8String(0).toString == "Apple") // semantic min (case-insensitive) + assert(stats.getUTF8String(1).toString == "Cherry") // semantic max (case-insensitive) + df.unpersist() + } + + // A WKB-encoded POINT(1 2), used to build Geometry/Geography test values. + private val wkbPoint = "0101000000000000000000F03F0000000000000040" + .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + + // Representative value for each top-level type the serializer claims to support. This drives the + // data-driven alignment test below: every type here is both asserted supported by + // isSupportedByArrow AND actually cached and read back, so a type that is claimed supported but + // fails during stats collection or read (as top-level geometry/geography once did) is caught. + private val sampleVariant = { + val v = VariantBuilder.parseJson("""{"a":1}""", false) + new VariantVal(v.getValue, v.getMetadata) + } + + private val topLevelTypeSamples: Seq[(DataType, Any)] = Seq( + (BooleanType, true), + (ByteType, 1.toByte), + (ShortType, 1.toShort), + (IntegerType, 1), + (LongType, 1L), + (FloatType, 1.0f), + (DoubleType, 1.0), + (StringType, "x"), + (BinaryType, Array[Byte](1, 2, 3)), + (NullType, null), + (DateType, Date.valueOf("2020-01-01")), + (TimestampType, Timestamp.valueOf("2020-01-01 00:00:00")), + (TimestampNTZType, LocalDateTime.parse("2020-01-01T00:00:00")), + (TimeType(6), LocalTime.parse("12:00:00")), + (DecimalType(10, 2), BigDecimal("1.23").bigDecimal), + (YearMonthIntervalType(), Period.ofMonths(3)), + (DayTimeIntervalType(), Duration.ofSeconds(5)), + (CalendarIntervalType, new CalendarInterval(1, 2, 3L)), + (ArrayType(IntegerType), Seq(1, 2, 3)), + (StructType(Seq(StructField("a", IntegerType))), Row(1)), + (MapType(StringType, IntegerType), Map("a" -> 1)), + (new ExamplePointUDT(), new ExamplePoint(1.0, 2.0)), + (VariantType, sampleVariant), + (GeometryType(4326), Geometry.fromWKB(wkbPoint, 4326)), + (GeographyType(4326), Geography.fromWKB(wkbPoint, 4326))) + + // Maps a DataType to a stable key identifying which isSupportedByArrow branch claims it. Two + // types share a key iff they are accepted by the same case arm. Used by the coverage test below + // to assert topLevelTypeSamples exercises every branch, so the hand-written sample list stays in + // sync with isSupportedByArrow as new branches are added. + private def supportedBranchKey(dt: DataType): String = dt match { + case BooleanType => "boolean" + case ByteType => "byte" + case ShortType => "short" + case IntegerType => "integer" + case LongType => "long" + case FloatType => "float" + case DoubleType => "double" + case _: StringType => "string" + case BinaryType => "binary" + case NullType => "null" + case _: DecimalType => "decimal" + case DateType => "date" + case TimestampType => "timestamp" + case TimestampNTZType => "timestampNTZ" + case _: TimeType => "time" + case _: YearMonthIntervalType => "yearMonthInterval" + case _: DayTimeIntervalType => "dayTimeInterval" + case CalendarIntervalType => "calendarInterval" + case _: ArrayType => "array" + case _: StructType => "struct" + case _: MapType => "map" + case _: UserDefinedType[_] => "udt" + case _: GeometryType => "geometry" + case _: GeographyType => "geography" + case _: VariantType => "variant" + case _ => s"UNSUPPORTED($dt)" + } + + // One representative type per isSupportedByArrow branch. supportedBranchKey must return a + // distinct, non-UNSUPPORTED key for each, mirroring the match arms in isSupportedByArrow. When a + // branch is added to isSupportedByArrow, add its representative here; the assertions below then + // force a matching entry in topLevelTypeSamples. + private val supportedBranchRepresentatives: Seq[DataType] = Seq( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + StringType, BinaryType, NullType, DecimalType(10, 2), DateType, TimestampType, + TimestampNTZType, TimeType(6), YearMonthIntervalType(), DayTimeIntervalType(), + CalendarIntervalType, ArrayType(IntegerType), StructType(Seq(StructField("a", IntegerType))), + MapType(StringType, IntegerType), new ExamplePointUDT(), GeometryType(4326), + GeographyType(4326), VariantType) + + test("topLevelTypeSamples covers every isSupportedByArrow branch") { + // Cross-checks the hand-written sample list against isSupportedByArrow so the two cannot drift: + // every representative type is actually claimed supported, and every branch it represents has + // at least one sample exercising the cache+read path in the test below. + val representativeKeys = supportedBranchRepresentatives.map { dt => + assert(ArrowUtils.isSupportedByArrow(dt), + s"representative type $dt is expected to be claimed supported by isSupportedByArrow") + val key = supportedBranchKey(dt) + assert(!key.startsWith("UNSUPPORTED"), + s"supportedBranchKey has no branch for representative type $dt; mirror the new " + + "isSupportedByArrow case here") + key + } + // No two representatives may collapse to the same branch, or a branch could go uncovered while + // appearing covered. + assert(representativeKeys.distinct.size == representativeKeys.size, + s"representatives map to duplicate branch keys: " + + s"${representativeKeys.diff(representativeKeys.distinct).distinct.mkString(", ")}") + + val sampleKeys = topLevelTypeSamples.map { case (dt, _) => supportedBranchKey(dt) }.toSet + val uncovered = representativeKeys.filterNot(sampleKeys.contains) + assert(uncovered.isEmpty, + s"isSupportedByArrow branches with no entry in topLevelTypeSamples: " + + s"${uncovered.mkString(", ")}. Add a representative value to topLevelTypeSamples.") + } + + test("every type claimed supported by isSupportedByArrow can be cached and read back") { + // Guards against the failure mode where a type is added to isSupportedByArrow but the stats + // collector (createColumnStats) or read path (needsFallback/ArrowColumnReader) is not updated + // to match. Driven by isSupportedByArrow itself rather than a hand-maintained list (the + // coverage test above enforces that every isSupportedByArrow branch appears here), so the + // claim and the implementation are cross-checked on the same set of types. + Seq(false, true).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + topLevelTypeSamples.foreach { case (dt, value) => + assert(ArrowUtils.isSupportedByArrow(dt), + s"test sample type $dt is expected to be claimed supported by isSupportedByArrow") + // Compare the cached read against an uncached baseline built from the same input, so any + // corruption that preserves row count (wrong value, truncated bytes, dropped fields) is + // caught, not just the row count. checkAnswer does type-aware comparison, covering + // binary, arrays, decimals, geometry/geography, variant, etc. Use two rows including a + // null to also exercise null handling per type. + val inputs = Seq(value, null) + val expected = singlePartDf(inputs, dt).collect().toSeq + val df = singlePartDf(inputs, dt).cache() + try { + assert(df.count() == 2, s"count mismatch for $dt (vectorized=$vectorized)") + checkAnswer(df, expected) + } finally { + df.unpersist() + InMemoryRelation.clearSerializer() + } + } + } + } + } + + test("duplicated column names roundtrip through the cache") { + // Arrow schemas are allowed to contain duplicate field names, and nothing in the cache path + // is keyed by name: vectors are read positionally (getFieldVectors / getVector(index)) and + // column pruning maps selected attributes to cache columns by exprId. The name-collision + // handling needed for the Python<->JVM Arrow exchange does not apply here because cached + // batches never cross into Python. This test pins that property for both read paths and for + // pruning a single one of the duplicated columns. + Seq(false, true).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + val df = spark.range(3).selectExpr("id as a", "(id * 10) as a", "string(id) as a") + df.cache() + try { + checkAnswer(df, Seq(Row(0L, 0L, "0"), Row(1L, 10L, "1"), Row(2L, 20L, "2"))) + // Selecting a duplicated name directly is ambiguous, so rename above the cached + // relation and project one column. The scan still prunes against the duplicate-name + // cache schema, resolving the column by exprId. + val pruned = df.toDF("x", "y", "z").select("y") + checkAnswer(pruned, Seq(Row(0L), Row(10L), Row(20L))) + } finally { + df.unpersist() + InMemoryRelation.clearSerializer() + } + } + } + } + + test("top-level geometry and geography roundtrip through the cache") { + Seq(false, true).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + Seq[(DataType, Any)]( + (GeometryType(4326), Geometry.fromWKB(wkbPoint, 4326)), + (GeographyType(4326), Geography.fromWKB(wkbPoint, 4326))).foreach { case (dt, value) => + val expected = singlePartDf(Seq(value, null), dt).collect().toSeq + val df = singlePartDf(Seq(value, null), dt).cache() + try { + // Validate the actual roundtripped value (WKB + SRID), not just nullness: the cached + // read must equal the uncached baseline. checkAnswer compares the geometry/geography + // value, so a corrupted WKB or dropped SRID would fail here. + checkAnswer(df, expected) + // Stats: geometry/geography use GeoColumnStats, so no min/max bounds but a + // null count of 1. Stats layout: [lower(0), upper(1), nullCount(2), count(3), + // sizeInBytes(4)]. + val stats = cachedStats(df) + assert(stats.isNullAt(0), s"$dt should have null lower bound") + assert(stats.isNullAt(1), s"$dt should have null upper bound") + assert(stats.getInt(2) == 1, s"$dt null count should be 1") + // sizeInBytes must reflect the BinaryView payload via getBinaryView, not be zero. + // The non-null value contributes its WKB byte length, so the total must exceed the + // WKB size; a regression to row.getBinary or a skipped size would fail this. + assert(stats.getLong(4) > wkbPoint.length, + s"$dt sizeInBytes (${stats.getLong(4)}) should exceed the WKB length " + + s"(${wkbPoint.length})") + } finally { + df.unpersist() + InMemoryRelation.clearSerializer() + } + } + } + } + } + + test("columnar input backed by LargeVarCharVector roundtrips via the slow path") { + // ArrowColumnVector accepts LargeVarCharVector (64-bit offsets) for StringType. The zero-copy + // path serializes/reloads under a largeVarTypes=false schema (32-bit offsets), which would + // corrupt such data, so the serializer must fall back to the row-based slow path. Build the + // ColumnarBatch inside the task to avoid serializing it to executors. + val schema = Seq(AttributeReference("v", StringType, nullable = true)()) + val conf = spark.sessionState.conf + val ser = new ArrowCachedBatchSerializer + val batchRdd = spark.sparkContext.parallelize(Seq(0), 1).mapPartitions { _ => + val alloc = ArrowUtils.rootAllocator.newChildAllocator("test-large-varchar", 0, Long.MaxValue) + val lv = new LargeVarCharVector("v", alloc) + lv.allocateNew(2) + lv.setSafe(0, "hello".getBytes("UTF-8")) + lv.setSafe(1, "world".getBytes("UTF-8")) + lv.setValueCount(2) + Iterator(new ColumnarBatch(Array[ColumnVector](new ArrowColumnVector(lv)), 2)) + } + val cached = ser.convertColumnarBatchToCachedBatch( + batchRdd, schema, StorageLevel.MEMORY_ONLY, conf) + cached.persist() + try { + val values = ser.convertCachedBatchToInternalRow(cached, schema, schema, conf) + .map(_.getString(0)).collect() + assert(values.sorted.sameElements(Array("hello", "world")), + s"expected [hello, world] but got [${values.mkString(", ")}]") + } finally { + cached.unpersist() + } + } + + test("CalendarInterval microsecond overflow produces a clear diagnostic") { + // Arrow stores intervals in nanoseconds, so a CalendarInterval whose microseconds exceed + // Long.MaxValue/1000 overflows when written. The serializer must surface a clear error naming + // the type/limit rather than an opaque "long overflow" ArithmeticException. A normal-range + // value must still cache fine. + val normal = new CalendarInterval(1, 2, 3000000L) + val normalDf = singlePartDf(Seq(normal), CalendarIntervalType).cache() + try { + assert(normalDf.count() == 1) + } finally { + normalDf.unpersist() + InMemoryRelation.clearSerializer() + } + + val overflow = new CalendarInterval(0, 0, Long.MaxValue / 1000L + 1L) + val overflowDf = singlePartDf(Seq(overflow), CalendarIntervalType).cache() + try { + val e = intercept[Exception](overflowDf.count()) + assert(Utils.exceptionString(e).contains("Arrow cache cannot represent a CalendarInterval"), + s"expected a clear CalendarInterval overflow message, got: ${Utils.exceptionString(e)}") + } finally { + overflowDf.unpersist() + InMemoryRelation.clearSerializer() + } + } + + test("nonpositive maxRecordsPerBatch caches all rows in a single batch") { + // A nonpositive maxRecordsPerBatch means unlimited; without the `<= 0` guard the write + // iterator would emit zero-row batches forever instead of finishing. + Seq("0", "-1").foreach { v => + withSQLConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> v) { + val df = spark.range(0, 100).repartition(1).cache() + try { + assert(df.count() == 100, s"maxRecordsPerBatch=$v") + assert(df.collect().length == 100, s"maxRecordsPerBatch=$v") + } finally { + df.unpersist() + InMemoryRelation.clearSerializer() + } + } + } + } + + test("row path splits batches by record and byte limits") { + val schema = Seq(AttributeReference("v", IntegerType, nullable = true)()) + val ser = new ArrowCachedBatchSerializer + + // Record limit: 95 rows in one partition with maxRecordsPerBatch=10 -> ceil(95 / 10) = 10 + // cached batches, none exceeding 10 rows. + withSQLConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> "10") { + val conf = spark.sessionState.conf + val rowRdd = spark.sparkContext.parallelize(0 until 95, 1).map(InternalRow(_)) + val cached = ser.convertInternalRowToCachedBatch( + rowRdd, schema, StorageLevel.MEMORY_ONLY, conf) + val batchRows = cached.map(_.numRows).collect() + assert(batchRows.length == 10, s"expected 10 batches, got ${batchRows.length}") + assert(batchRows.sum == 95, "all rows must be preserved") + assert(batchRows.forall(_ <= 10), "no batch may exceed the record limit") + } + + // Byte limit: with a record limit high enough not to bind and a tiny byte limit, the batch is + // cut by estimated bytes instead, so more than one batch is produced and each is small. This + // pins that maxBytesPerBatch is honored (it was previously ignored on this path), matching the + // behavior of the other Arrow writers. + withSQLConf( + SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> "100000", + SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> "64") { + val conf = spark.sessionState.conf + val rowRdd = spark.sparkContext.parallelize(0 until 95, 1).map(InternalRow(_)) + val cached = ser.convertInternalRowToCachedBatch( + rowRdd, schema, StorageLevel.MEMORY_ONLY, conf) + val batchRows = cached.map(_.numRows).collect() + assert(batchRows.length > 1, + s"a tiny byte limit should force multiple batches, got ${batchRows.length}") + assert(batchRows.sum == 95, "all rows must be preserved") + } + } + + test("row read does not drop rows after an empty cached batch") { + // A zero-row cached batch (legal input from a columnar source) must not terminate the row + // iterator early: subsequent non-empty batches must still be read. + val schema = Seq(AttributeReference("v", IntegerType, nullable = true)()) + val conf = spark.sessionState.conf + val ser = new ArrowCachedBatchSerializer + // One empty ColumnarBatch followed by a one-row batch, each built inside the task. + val batchRdd = spark.sparkContext.parallelize(Seq(0), 1).mapPartitions { _ => + def intBatch(values: Int*): ColumnarBatch = { + val alloc = ArrowUtils.rootAllocator.newChildAllocator("test-empty", 0, Long.MaxValue) + val iv = new IntVector("v", alloc) + iv.allocateNew(values.length) + values.zipWithIndex.foreach { case (x, i) => iv.setSafe(i, x) } + iv.setValueCount(values.length) + new ColumnarBatch(Array[ColumnVector](new ArrowColumnVector(iv)), values.length) + } + Iterator(intBatch(), intBatch(42)) + } + val cached = ser.convertColumnarBatchToCachedBatch( + batchRdd, schema, StorageLevel.MEMORY_ONLY, conf) + cached.persist() + try { + val values = ser.convertCachedBatchToInternalRow(cached, schema, schema, conf) + .map(_.getInt(0)).collect() + assert(values.sameElements(Array(42)), s"expected [42] but got [${values.mkString(", ")}]") + } finally { + cached.unpersist() + } + } + + test("columnar read with prefetch does not release the in-use batch's buffers") { + // With prefetch enabled, the next batch is deserialized on a background thread while the + // current batch is consumed. The previous root must only be closed on the consumer thread + // (in next()), never by the background prefetch; otherwise the ArrowColumnVectors backing the + // batch currently held by the consumer point at released memory. Reading the held batch after + // giving the background prefetch time to run reproduces that use-after-free if reintroduced. + withSQLConf(SQLConf.ARROW_CACHE_PREFETCH_ENABLED.key -> "true") { + val schema = Seq(AttributeReference("v", IntegerType, nullable = true)()) + val conf = spark.sessionState.conf + val ser = new ArrowCachedBatchSerializer + val batchRdd = spark.sparkContext.parallelize(Seq(0), 1).mapPartitions { _ => + (0 until 5).iterator.map { x => + val alloc = ArrowUtils.rootAllocator.newChildAllocator(s"prefetch-$x", 0, Long.MaxValue) + val iv = new IntVector("v", alloc) + iv.allocateNew(1) + iv.setSafe(0, x * 10) + iv.setValueCount(1) + new ColumnarBatch(Array[ColumnVector](new ArrowColumnVector(iv)), 1) + } + } + val cached = ser.convertColumnarBatchToCachedBatch( + batchRdd, schema, StorageLevel.MEMORY_ONLY, conf) + cached.persist() + try { + val values = ser.convertCachedBatchToColumnarBatch(cached, schema, schema, conf) + .mapPartitions { it => + val out = scala.collection.mutable.ArrayBuffer[Int]() + while (it.hasNext) { + val batch = it.next() // hold exactly one batch, per the ColumnarBatch contract + Thread.sleep(20) // give the background prefetch a chance to run before reading + out += batch.getRow(0).getInt(0) + } + out.iterator + }.collect() + assert(values.sorted.sameElements(Array(0, 10, 20, 30, 40)), + s"expected [0, 10, 20, 30, 40] but got [${values.sorted.mkString(", ")}]") + } finally { + cached.unpersist() + } + } + } + + test("drainAndClosePrefetch closes the produced root even on an interrupted thread") { + // Reproduces the killed-task race: the cleanup listener can run with the task thread already + // interrupted. drainAndClosePrefetch must still join the worker and close the root it produced, + // so the child allocator can be closed without "Memory was leaked by query", and must restore + // the interrupt afterwards. + val arrowSchema = ArrowUtils.toArrowSchema( + StructType(Seq(StructField("v", IntegerType))), "UTC", false, false) + val alloc = ArrowUtils.rootAllocator.newChildAllocator("test-drain-interrupt", 0, Long.MaxValue) + val executor = java.util.concurrent.Executors.newSingleThreadExecutor() + val future = executor.submit(new java.util.concurrent.Callable[VectorSchemaRoot] { + override def call(): VectorSchemaRoot = { + val root = VectorSchemaRoot.create(arrowSchema, alloc) + root.allocateNew() // allocate buffers so a leak would be detected on alloc.close() + root + } + }) + // Interrupt the current thread before cleanup, then assert cleanup still drains/closes. + Thread.currentThread().interrupt() + val result = ArrowCachedBatchSerializer.drainAndClosePrefetch(executor, future) + assert(result == null) + assert(Thread.interrupted(), "the interrupt status must be restored (and cleared here)") + // If the produced root was not closed, this throws "Memory was leaked by query". + alloc.close() + } +} + +/** + * Tests that ArrowCachedBatch and ArrowCachedBatchSerializer are registered in KryoSerializer. + * Without the registration, persisting with DISK_ONLY storage level would fail when + * spark.kryo.registrationRequired=true because Kryo rejects unregistered classes. + */ +class ArrowCachedBatchKryoRegistrationSuite extends QueryTest with SharedSparkSession { + + override def sparkConf: SparkConf = super.sparkConf + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, classOf[ArrowCachedBatchSerializer].getName) + .set("spark.kryo.registrationRequired", "true") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + override def beforeAll(): Unit = { + super.beforeAll() + InMemoryRelation.clearSerializer() + } + + override def afterAll(): Unit = { + InMemoryRelation.clearSerializer() + super.afterAll() + } + + test("ArrowCachedBatch and ArrowCachedBatchSerializer are registered in KryoSerializer") { + withTable("t1") { + sql("CREATE TABLE t1 AS SELECT 1 AS a") + checkAnswer(sql("SELECT * FROM t1").persist(StorageLevel.DISK_ONLY), Seq(Row(1))) + } + } +}