From 1fc66f68703e0b14e03fe0ed5ca93e9af20f41a9 Mon Sep 17 00:00:00 2001
From: Cheng Su
Date: Thu, 1 Apr 2021 23:10:34 -0700
Subject: [PATCH] [SPARK-34862][SQL] Support nested column in ORC vectorized reader

### What changes were proposed in this pull request?

This PR adds support for nested column types in the Spark ORC vectorized reader. Currently the ORC vectorized reader [does not support nested column types (struct, array and map)](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala#L138). We implemented a nested-column vectorized reader for FB-ORC in our internal fork of Spark and saw a performance improvement over the non-vectorized reader when reading nested columns. It also helps non-nested column performance when a query reads non-nested and nested columns together.

Before this PR:

* `OrcColumnVector` is the implementation class of Spark's `ColumnVector` that wraps Hive's/ORC's `ColumnVector` to read `AtomicType` data.

After this PR:

* `OrcColumnVector` becomes an abstract class holding the interface shared by the ORC column vector implementations: `OrcAtomicColumnVector` (for `AtomicType`), `OrcArrayColumnVector` (for `ArrayType`), `OrcMapColumnVector` (for `MapType`) and `OrcStructColumnVector` (for `StructType`). The original logic for reading `AtomicType` data moves from `OrcColumnVector` to `OrcAtomicColumnVector`. The abstract class is needed to support nested columns (i.e. nested column vectors).
* A utility method `OrcColumnVectorUtils.toOrcColumnVector` is added to create Spark's `OrcColumnVector` from Hive's/ORC's `ColumnVector`.
* A new user-facing config `spark.sql.orc.enableNestedColumnVectorizedReader` is added to enable/disable the vectorized reader for nested columns. The default value is false (disabled by default). For tables with deeply nested columns, the vectorized reader may take much more memory than the non-vectorized reader because it allocates a vector per sub-column, so the config provides a workaround for OOMs in queries that read wide and deeply nested columns. We plan to enable it by default in 3.3 and leave it disabled in 3.2 in case of unknown bugs. A usage sketch is included after the benchmark code below.

### Why are the changes needed?

Improve query performance when reading nested columns from the ORC file format. Tested by locally adding a small benchmark in `OrcReadBenchmark.scala`; the results below show a roughly 2.5x run time improvement for the vectorized nested-column case.
```
Running benchmark: SQL Nested Column Scan
  Running case: Native ORC MR
  Stopped after 2 iterations, 37850 ms
  Running case: Native ORC Vectorized (Enabled Nested Column)
  Stopped after 2 iterations, 15892 ms
  Running case: Native ORC Vectorized (Disabled Nested Column)
  Stopped after 2 iterations, 37954 ms
  Running case: Hive built-in ORC
  Stopped after 2 iterations, 35118 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
SQL Nested Column Scan:                          Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------------
Native ORC MR                                            18706          18925         310         0.1       17839.6       1.0X
Native ORC Vectorized (Enabled Nested Column)             7625           7946         455         0.1        7271.6       2.5X
Native ORC Vectorized (Disabled Nested Column)           18415          18977         796         0.1       17561.5       1.0X
Hive built-in ORC                                        17469          17559         127         0.1       16660.1       1.1X
```

Benchmark:

```
nestedColumnScanBenchmark(1024 * 1024)

def nestedColumnScanBenchmark(values: Int): Unit = {
  val benchmark = new Benchmark(s"SQL Nested Column Scan", values, output = output)

  withTempPath { dir =>
    withTempTable("t1", "nativeOrcTable", "hiveOrcTable") {
      import spark.implicits._
      spark.range(values).map(_ => Random.nextLong).map { x =>
        val arrayOfStructColumn = (0 until 5).map(i => (x + i, s"$x" * 5))
        val mapOfStructColumn = Map(
          s"$x" -> (x * 0.1, (x, s"$x" * 100)),
          (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)),
          (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300)))
        (arrayOfStructColumn, mapOfStructColumn)
      }.toDF("col1", "col2")
        .createOrReplaceTempView("t1")

      prepareTable(dir, spark.sql(s"SELECT * FROM t1"))

      benchmark.addCase("Native ORC MR") { _ =>
        withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
          spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM nativeOrcTable").noop()
        }
      }

      benchmark.addCase("Native ORC Vectorized (Enabled Nested Column)") { _ =>
        spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM nativeOrcTable").noop()
      }

      benchmark.addCase("Native ORC Vectorized (Disabled Nested Column)") { _ =>
        withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
          spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM nativeOrcTable").noop()
        }
      }

      benchmark.addCase("Hive built-in ORC") { _ =>
        spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM hiveOrcTable").noop()
      }

      benchmark.run()
    }
  }
}
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added one simple test in `OrcSourceSuite.scala` to verify correctness. More unit tests and a proper benchmark are definitely needed, but I want to collect feedback first before crafting more tests.

Closes #31958 from c21/orc-vector.
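For reference, here is a minimal usage sketch of the new config. It is not part of the patch: the `SparkSession` setup and the output path `/tmp/orc_nested_demo` are illustrative assumptions; only the config key `spark.sql.orc.enableNestedColumnVectorizedReader` comes from this PR.

```
// Minimal usage sketch (illustrative, not from this patch). The output path and
// session setup are assumptions; the config key is introduced by this PR.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("orc-nested-vectorized-demo").getOrCreate()
import spark.implicits._

// Write a small ORC table with a nested (struct) column.
spark.range(10)
  .map(x => (x, (x, s"$x" * 10)))
  .toDF("id", "struct_col")
  .write.mode("overwrite").orc("/tmp/orc_nested_demo")

// Opt in to the nested-column vectorized reader (disabled by default in this patch).
spark.conf.set("spark.sql.orc.enableNestedColumnVectorizedReader", "true")

// Reads of the nested column can now go through the ORC vectorized reader, provided
// the other supportBatch conditions hold (vectorized reader and whole-stage codegen
// enabled, schema within spark.sql.codegen.maxFields).
spark.read.orc("/tmp/orc_nested_demo")
  .selectExpr("id", "struct_col._1", "struct_col._2")
  .show()
```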
Authored-by: Cheng Su
Signed-off-by: Liang-Chi Hsieh
---
 project/MimaExcludes.scala                     |  17 +-
 .../apache/spark/sql/internal/SQLConf.scala    |  10 ++
 .../datasources/orc/OrcArrayColumnVector.java  | 115 +++++++++++++
 .../orc/OrcAtomicColumnVector.java             | 161 ++++++++++++++++++
 .../datasources/orc/OrcColumnVector.java       | 151 ++--------------
 .../datasources/orc/OrcColumnVectorUtils.java  |  71 ++++++++
 .../orc/OrcColumnarBatchReader.java            |   3 +-
 .../datasources/orc/OrcMapColumnVector.java    | 118 +++++++++++++
 .../orc/OrcStructColumnVector.java             | 105 ++++++++++++
 .../datasources/orc/OrcFileFormat.scala        |  18 +-
 .../datasources/orc/OrcSourceSuite.scala       |  29 ++++
 11 files changed, 657 insertions(+), 141 deletions(-)
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4f879b881a6fe..dc11f338a006e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,7 +40,22 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.write.V1WriteBuilder"),
 
     // [SPARK-33955] Add latest offsets to source progress
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceProgress.this")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceProgress.this"),
+
+    // [SPARK-34862][SQL] Support nested column in ORC vectorized reader
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getBoolean"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getByte"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getShort"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getInt"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getLong"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getFloat"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getDouble"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getDecimal"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getUTF8String"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getBinary"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getArray"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getMap"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getChild")
   )
 
   // Exclude rules for 3.1.x
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5c5cb52f30951..d91bb59f68c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -839,6 +839,13 @@ object SQLConf { .intConf .createWithDefault(4096) + val ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED = + buildConf("spark.sql.orc.enableNestedColumnVectorizedReader") + .doc("Enables vectorized orc decoding for nested column.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .version("1.4.0") @@ -3339,6 +3346,9 @@ class SQLConf extends Serializable with Logging { def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def orcVectorizedReaderNestedColumnEnabled: Boolean = + getConf(ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java new file mode 100644 index 0000000000000..6e13e97b4cbcc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; + +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector implementation for Spark's {@link ArrayType}. 
+ */ +public class OrcArrayColumnVector extends OrcColumnVector { + private final OrcColumnVector data; + private final long[] offsets; + private final long[] lengths; + + OrcArrayColumnVector( + DataType type, + ColumnVector vector, + OrcColumnVector data, + long[] offsets, + long[] lengths) { + + super(type, vector); + + this.data = data; + this.offsets = offsets; + this.lengths = lengths; + } + + @Override + public ColumnarArray getArray(int rowId) { + return new ColumnarArray(data, (int) offsets[rowId], (int) lengths[rowId]); + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java new file mode 100644 index 0000000000000..c2d8334d928c0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.hadoop.hive.ql.exec.vector.*; + +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.RebaseDateTime; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector implementation for Spark's AtomicType. + */ +public class OrcAtomicColumnVector extends OrcColumnVector { + private final boolean isTimestamp; + private final boolean isDate; + + // Column vector for each type. Only 1 is populated for any type. + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + + OrcAtomicColumnVector(DataType type, ColumnVector vector) { + super(type, vector); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + if (type instanceof DateType) { + isDate = true; + } else { + isDate = false; + } + + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int getInt(int rowId) { + int value = (int) longData.vector[getRowIndex(rowId)]; + if (isDate) { + return RebaseDateTime.rebaseJulianToGregorianDays(value); + } else { + return value; + } + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index)); + } else { + return longData.vector[index]; + } + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + 
byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 6e55fedfc4deb..0becd2572f99c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -17,75 +17,29 @@ package org.apache.spark.sql.execution.datasources.orc; -import java.math.BigDecimal; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.*; - -import org.apache.spark.sql.catalyst.util.DateTimeUtils; -import org.apache.spark.sql.catalyst.util.RebaseDateTime; import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.TimestampType; -import org.apache.spark.sql.vectorized.ColumnarArray; -import org.apache.spark.sql.vectorized.ColumnarMap; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts - * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with - * Spark ColumnarVector. + * A column vector interface wrapping Hive's {@link ColumnVector}. + * + * Because Spark {@link ColumnarBatch} only accepts Spark's vectorized.ColumnVector, + * this column vector is used to adapt Hive ColumnVector with Spark ColumnarVector. 
*/ -public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { - private ColumnVector baseData; - private LongColumnVector longData; - private DoubleColumnVector doubleData; - private BytesColumnVector bytesData; - private DecimalColumnVector decimalData; - private TimestampColumnVector timestampData; - private final boolean isTimestamp; - private final boolean isDate; - +public abstract class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private final ColumnVector baseData; private int batchSize; OrcColumnVector(DataType type, ColumnVector vector) { super(type); - if (type instanceof TimestampType) { - isTimestamp = true; - } else { - isTimestamp = false; - } - - if (type instanceof DateType) { - isDate = true; - } else { - isDate = false; - } - baseData = vector; - if (vector instanceof LongColumnVector) { - longData = (LongColumnVector) vector; - } else if (vector instanceof DoubleColumnVector) { - doubleData = (DoubleColumnVector) vector; - } else if (vector instanceof BytesColumnVector) { - bytesData = (BytesColumnVector) vector; - } else if (vector instanceof DecimalColumnVector) { - decimalData = (DecimalColumnVector) vector; - } else if (vector instanceof TimestampColumnVector) { - timestampData = (TimestampColumnVector) vector; - } else { - throw new UnsupportedOperationException(); - } - } - - public void setBatchSize(int batchSize) { - this.batchSize = batchSize; } @Override public void close() { - } @Override @@ -112,97 +66,18 @@ public int numNulls() { } } - /* A helper method to get the row index in a column. */ - private int getRowIndex(int rowId) { - return baseData.isRepeating ? 0 : rowId; - } - @Override public boolean isNullAt(int rowId) { return baseData.isNull[getRowIndex(rowId)]; } - @Override - public boolean getBoolean(int rowId) { - return longData.vector[getRowIndex(rowId)] == 1; - } - - @Override - public byte getByte(int rowId) { - return (byte) longData.vector[getRowIndex(rowId)]; - } - - @Override - public short getShort(int rowId) { - return (short) longData.vector[getRowIndex(rowId)]; - } - - @Override - public int getInt(int rowId) { - int value = (int) longData.vector[getRowIndex(rowId)]; - if (isDate) { - return RebaseDateTime.rebaseJulianToGregorianDays(value); - } else { - return value; - } - } - - @Override - public long getLong(int rowId) { - int index = getRowIndex(rowId); - if (isTimestamp) { - return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index)); - } else { - return longData.vector[index]; - } - } - - @Override - public float getFloat(int rowId) { - return (float) doubleData.vector[getRowIndex(rowId)]; - } - - @Override - public double getDouble(int rowId) { - return doubleData.vector[getRowIndex(rowId)]; - } - - @Override - public Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); - return Decimal.apply(data, precision, scale); - } - @Override - public UTF8String getUTF8String(int rowId) { - if (isNullAt(rowId)) return null; - int index = getRowIndex(rowId); - BytesColumnVector col = bytesData; - return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); - } - - @Override - public byte[] getBinary(int rowId) { - if (isNullAt(rowId)) return null; - int index = getRowIndex(rowId); - byte[] binary = new byte[bytesData.length[index]]; - System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, 
binary.length); - return binary; - } - - @Override - public ColumnarArray getArray(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public ColumnarMap getMap(int rowId) { - throw new UnsupportedOperationException(); + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; } - @Override - public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { - throw new UnsupportedOperationException(); + /* A helper method to get the row index in a column. */ + protected int getRowIndex(int rowId) { + return baseData.isRepeating ? 0 : rowId; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java new file mode 100644 index 0000000000000..3bc7cc8f80142 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.hadoop.hive.ql.exec.vector.*; + +import org.apache.spark.sql.types.*; + +/** + * Utility class for {@link OrcColumnVector}. + */ +class OrcColumnVectorUtils { + + /** + * Convert a Hive's {@link ColumnVector} to a Spark's {@link OrcColumnVector}. 
+ * + * @param type The data type of column vector + * @param vector Hive's column vector + * @return Spark's column vector + */ + static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) { + if (vector instanceof LongColumnVector || + vector instanceof DoubleColumnVector || + vector instanceof BytesColumnVector || + vector instanceof DecimalColumnVector || + vector instanceof TimestampColumnVector) { + return new OrcAtomicColumnVector(type, vector); + } else if (vector instanceof StructColumnVector) { + StructColumnVector structVector = (StructColumnVector) vector; + OrcColumnVector[] fields = new OrcColumnVector[structVector.fields.length]; + int ordinal = 0; + for (StructField f : ((StructType) type).fields()) { + fields[ordinal] = toOrcColumnVector(f.dataType(), structVector.fields[ordinal]); + ordinal++; + } + return new OrcStructColumnVector(type, vector, fields); + } else if (vector instanceof ListColumnVector) { + ListColumnVector listVector = (ListColumnVector) vector; + OrcColumnVector dataVector = toOrcColumnVector( + ((ArrayType) type).elementType(), listVector.child); + return new OrcArrayColumnVector( + type, vector, dataVector, listVector.offsets, listVector.lengths); + } else if (vector instanceof MapColumnVector) { + MapColumnVector mapVector = (MapColumnVector) vector; + MapType mapType = (MapType) type; + OrcColumnVector keysVector = toOrcColumnVector(mapType.keyType(), mapVector.keys); + OrcColumnVector valuesVector = toOrcColumnVector(mapType.valueType(), mapVector.values); + return new OrcMapColumnVector( + type, vector, keysVector, valuesVector, mapVector.offsets, mapVector.lengths); + } else { + throw new IllegalArgumentException( + String.format("OrcColumnVectorUtils.toOrcColumnVector should not take %s as type " + + "and %s as vector", type, vector)); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 6a4b116cdef0b..40ed0b2454c12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -180,7 +180,8 @@ public void initBatch( missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { - orcVectorWrappers[i] = new OrcColumnVector(dt, wrap.batch().cols[colId]); + orcVectorWrappers[i] = OrcColumnVectorUtils.toOrcColumnVector( + dt, wrap.batch().cols[colId]); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java new file mode 100644 index 0000000000000..ace8d157792dc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector implementation for Spark's {@link MapType}. + */ +public class OrcMapColumnVector extends OrcColumnVector { + private final OrcColumnVector keys; + private final OrcColumnVector values; + private final long[] offsets; + private final long[] lengths; + + OrcMapColumnVector( + DataType type, + ColumnVector vector, + OrcColumnVector keys, + OrcColumnVector values, + long[] offsets, + long[] lengths) { + + super(type, vector); + + this.keys = keys; + this.values = values; + this.offsets = offsets; + this.lengths = lengths; + } + + @Override + public ColumnarMap getMap(int ordinal) { + return new ColumnarMap(keys, values, (int) offsets[ordinal], (int) lengths[ordinal]); + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java new file mode 100644 index 0000000000000..48e540d22095e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector implementation for Spark's {@link StructType}. + */ +public class OrcStructColumnVector extends OrcColumnVector { + private final OrcColumnVector[] fields; + + OrcStructColumnVector(DataType type, ColumnVector vector, OrcColumnVector[] fields) { + super(type, vector); + + this.fields = fields; + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + return fields[ordinal]; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 8f4d1e5098029..3a5097441ab37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -131,11 +131,27 @@ class OrcFileFormat } } + private def supportBatchForNestedColumn( + sparkSession: SparkSession, + schema: StructType): Boolean = { + val hasNestedColumn = schema.map(_.dataType).exists { + case _: ArrayType | _: MapType | _: StructType => true + case _ => false + } + if (hasNestedColumn) { + sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled + } else { + true + } + 
} + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && - schema.forall(_.dataType.isInstanceOf[AtomicType]) + schema.forall(s => supportDataType(s.dataType) && + !s.dataType.isInstanceOf[UserDefinedType[_]]) && + supportBatchForNestedColumn(sparkSession, schema) } override def isSplitable( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index c763f4c9428c8..ad69bee4fb643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -542,6 +543,7 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDa } class OrcSourceSuite extends OrcSuite with SharedSparkSession { + import testImplicits._ protected override def beforeAll(): Unit = { super.beforeAll() @@ -602,4 +604,31 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession { checkAnswer(spark.read.orc(path), Seq(Row(0), Row(1), Row(2))) } } + + test("SPARK-34862: Support ORC vectorized reader for nested column") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(10).map { x => + val stringColumn = s"$x" * 10 + val structColumn = (x, s"$x" * 100) + val arrayColumn = (0 until 5).map(i => (x + i, s"$x" * 5)) + val mapColumn = Map( + s"$x" -> (x * 0.1, (x, s"$x" * 100)), + (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)), + (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300))) + (x, stringColumn, structColumn, arrayColumn, mapColumn) + }.toDF("int_col", "string_col", "struct_col", "array_col", "map_col") + df.write.format("orc").save(path) + + withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + val readDf = spark.read.orc(path) + val vectorizationEnabled = readDf.queryExecution.executedPlan.find { + case scan: FileSourceScanExec => scan.supportsColumnar + case _ => false + }.isDefined + assert(vectorizationEnabled) + checkAnswer(readDf, df) + } + } + } }
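Following the pattern of the new `OrcSourceSuite` test above, here is a hedged sketch of how one could check from user code whether a nested-column ORC scan actually took the columnar (vectorized) path. It is not part of the patch: the input path is an illustrative assumption, and `FileSourceScanExec` is internal API, so this is only a debugging aid.

```
// Hedged verification sketch (not from this patch); it mirrors the check used in the
// new OrcSourceSuite test. FileSourceScanExec is internal API; the path is an assumption.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FileSourceScanExec

val spark = SparkSession.builder().getOrCreate()
spark.conf.set("spark.sql.orc.enableNestedColumnVectorizedReader", "true")

val readDf = spark.read.orc("/tmp/orc_nested_demo")

// supportsColumnar is true only when OrcFileFormat.supportBatch accepts the schema,
// which with this patch includes nested struct/array/map columns when the config is on.
val vectorized = readDf.queryExecution.executedPlan.find {
  case scan: FileSourceScanExec => scan.supportsColumnar
  case _ => false
}.isDefined

println(s"nested ORC scan vectorized: $vectorized")
```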