diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ff84f4efbd81b..e4c58efa59d23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -566,6 +566,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") + .doc("When true, the Orc data source merges schemas collected from all data files, " + + "otherwise the schema is picked from a random data file.") + .booleanConf + .createWithDefault(false) + val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + "when reading data stored in HDFS. This configuration will be deprecated in the future " + @@ -1956,6 +1962,8 @@ class SQLConf extends Serializable with Logging { def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) + def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala new file mode 100644 index 0000000000000..99882b0f7c7b0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +object SchemaMergeUtils extends Logging { + /** + * Figures out a merged Parquet/ORC schema with a distributed Spark job. + */ + def mergeSchemasInParallel( + sparkSession: SparkSession, + files: Seq[FileStatus], + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]) + : Option[StructType] = { + val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) + + // !! HACK ALERT !! + // Here is a hack for Parquet, but it can be used by Orc as well. + // + // Parquet requires `FileStatus`es to read footers. 
+ // Here we try to send cached `FileStatus`es to executor side to avoid fetching them again. + // However, `FileStatus` is not `Serializable` + // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = files.map(f => (f.getPath.toString, f.getLen)) + + // Set the number of partitions to prevent following schema reads from generating many tasks + // in case of a small number of orc files. + val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), + sparkSession.sparkContext.defaultParallelism) + + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + + // Issues a Spark job to read Parquet/ORC schema in parallel. + val partiallyMergedSchemas = + sparkSession + .sparkContext + .parallelize(partialFileStatusInfo, numParallelism) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + val schemas = schemaReader(fakeFileStatuses, serializedConf.value, ignoreCorruptFiles) + + if (schemas.isEmpty) { + Iterator.empty + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergedSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } + }.collect() + + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 01f8ce7911d4a..f7c12598da209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -94,7 +94,7 @@ class OrcFileFormat sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - OrcUtils.readSchema(sparkSession, files) + OrcUtils.inferSchema(sparkSession, files, options) } override def prepareWrite( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index 0ad3862f6cf01..25f022bcdde89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -57,9 +57,20 @@ class OrcOptions( } shortOrcCompressionCodecNames(codecName) } + + /** + * Whether 
it merges schemas or not. When the given Orc files have different schemas, + * the schemas can be merged. By default use the value specified in SQLConf. + */ + val mergeSchema: Boolean = parameters + .get(MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlConf.isOrcSchemaMergingEnabled) } object OrcOptions { + val MERGE_SCHEMA = "mergeSchema" + // The ORC compression short names private val shortOrcCompressionCodecNames = Map( "none" -> "NONE", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index fb9f87ccddddf..12d4244e19812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -33,7 +33,9 @@ import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} object OrcUtils extends Logging { @@ -82,7 +84,6 @@ object OrcUtils extends Logging { : Option[StructType] = { val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() - // TODO: We need to support merge schema. Please see SPARK-11412. files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { case Some(schema) => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") @@ -90,6 +91,29 @@ object OrcUtils extends Logging { } } + /** + * Reads ORC file schemas in multi-threaded manner, using native version of ORC. + * This is visible for testing. + */ + def readOrcSchemasInParallel( + files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean): Seq[StructType] = { + ThreadUtils.parmap(files, "readingOrcSchemas", 8) { currentFile => + OrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles) + .map(s => CatalystSqlParser.parseDataType(s.toString).asInstanceOf[StructType]) + }.flatten + } + + def inferSchema(sparkSession: SparkSession, files: Seq[FileStatus], options: Map[String, String]) + : Option[StructType] = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + if (orcOptions.mergeSchema) { + SchemaMergeUtils.mergeSchemasInParallel( + sparkSession, files, OrcUtils.readOrcSchemasInParallel) + } else { + OrcUtils.readSchema(sparkSession, files) + } + } + /** * Returns the requested column ids from the given ORC file. Column id can be -1, which means the * requested column doesn't exist in the ORC file. Returns None if the given ORC file is empty. 
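Usage sketch for the new `spark.sql.orc.mergeSchema` configuration and the per-read `mergeSchema` option introduced above. It mirrors the tests added later in this patch; the paths, column names, and the `spark` session binding (e.g. from spark-shell) are hypothetical and not part of the change:

    // Two ORC partition directories written with different column sets (hypothetical paths).
    spark.range(0, 10).toDF("a").write.orc("/tmp/orc_merge/foo=1")   // schema: a bigint
    spark.range(0, 10).toDF("b").write.orc("/tmp/orc_merge/foo=2")   // schema: b bigint

    // Session-wide default for ORC reads; it stays false unless enabled explicitly.
    spark.conf.set("spark.sql.orc.mergeSchema", "true")

    // The per-read option takes precedence over the SQL configuration.
    val df = spark.read.option("mergeSchema", "true").orc("/tmp/orc_merge")
    df.printSchema()   // expected columns: a, b, and the partition column foo

With merging disabled (the default), the read keeps the previous behavior and resolves the schema from a single data file.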
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 1af6745525b84..9caa34b2d9652 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -476,79 +476,18 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp - val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) - - // !! HACK ALERT !! - // - // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es - // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` - // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well - // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These - // facts virtually prevents us to serialize `FileStatus`es. - // - // Since Parquet only relies on path and length information of those `FileStatus`es to read - // footers, here we just extract them (which can be easily serialized), send them to executor - // side, and resemble fake `FileStatus`es there. - val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) - - // Set the number of partitions to prevent following schema reads from generating many tasks - // in case of a small number of parquet files. - val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), - sparkSession.sparkContext.defaultParallelism) - - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles - - // Issues a Spark job to read Parquet schema in parallel. - val partiallyMergedSchemas = - sparkSession - .sparkContext - .parallelize(partialFileStatusInfo, numParallelism) - .mapPartitions { iterator => - // Resembles fake `FileStatus`es with serialized path and length information. 
- val fakeFileStatuses = iterator.map { case (path, length) => - new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) - }.toSeq - - // Reads footers in multi-threaded manner within each task - val footers = - ParquetFileFormat.readParquetFootersInParallel( - serializedConf.value, fakeFileStatuses, ignoreCorruptFiles) - - // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` - val converter = new ParquetToSparkSchemaConverter( - assumeBinaryIsString = assumeBinaryIsString, - assumeInt96IsTimestamp = assumeInt96IsTimestamp) - if (footers.isEmpty) { - Iterator.empty - } else { - var mergedSchema = ParquetFileFormat.readSchemaFromFooter(footers.head, converter) - footers.tail.foreach { footer => - val schema = ParquetFileFormat.readSchemaFromFooter(footer, converter) - try { - mergedSchema = mergedSchema.merge(schema) - } catch { case cause: SparkException => - throw new SparkException( - s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause) - } - } - Iterator.single(mergedSchema) - } - }.collect() - if (partiallyMergedSchemas.isEmpty) { - None - } else { - var finalSchema = partiallyMergedSchemas.head - partiallyMergedSchemas.tail.foreach { schema => - try { - finalSchema = finalSchema.merge(schema) - } catch { case cause: SparkException => - throw new SparkException( - s"Failed merging schema:\n${schema.treeString}", cause) - } - } - Some(finalSchema) + val reader = (files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) => { + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val converter = new ParquetToSparkSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp) + + readParquetFootersInParallel(conf, files, ignoreCorruptFiles) + .map(ParquetFileFormat.readSchemaFromFooter(_, converter)) } + + SchemaMergeUtils.mergeSchemasInParallel(sparkSession, filesToTouch, reader) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index 1cc6e61c845c0..3fe433861a3c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.orc +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession @@ -39,7 +41,7 @@ case class OrcTable( new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = - OrcUtils.readSchema(sparkSession, files) + OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap) override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = new OrcWriteBuilder(options, paths, formatName, supportsDataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala index 8c95349ef3be7..d5502ba5737c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import 
org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf /** @@ -32,6 +33,7 @@ import org.apache.spark.sql.internal.SQLConf * * -> OrcReadSchemaSuite * -> VectorizedOrcReadSchemaSuite + * -> MergedOrcReadSchemaSuite * * -> ParquetReadSchemaSuite * -> VectorizedParquetReadSchemaSuite @@ -134,6 +136,25 @@ class VectorizedOrcReadSchemaSuite } } +class MergedOrcReadSchemaSuite + extends ReadSchemaSuite + with AddColumnIntoTheMiddleTest + with HideColumnInTheMiddleTest + with AddNestedColumnTest + with HideNestedColumnTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key, "true") +} + class ParquetReadSchemaSuite extends ReadSchemaSuite with AddColumnIntoTheMiddleTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8f9cc629880eb..c9f5d9cb23e6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import java.util.Locale import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.OrcFile import org.apache.orc.OrcProto.ColumnEncoding.Kind.{DICTIONARY_V2, DIRECT, DIRECT_V2} @@ -31,10 +31,12 @@ import org.apache.orc.OrcProto.Stream.Kind import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} +import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.util.Utils case class OrcData(intField: Int, stringField: String) @@ -188,6 +190,49 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } + protected def testMergeSchemasInParallel( + ignoreCorruptFiles: Boolean, + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]): Unit = { + withSQLConf( + SQLConf.IGNORE_CORRUPT_FILES.key -> ignoreCorruptFiles.toString, + SQLConf.ORC_IMPLEMENTATION.key -> orcImp) { + withTempDir { dir => + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) + val basePath = dir.getCanonicalPath + + val path1 = new Path(basePath, "first") + val path2 = new Path(basePath, "second") + val path3 = new Path(basePath, "third") + + spark.range(1).toDF("a").coalesce(1).write.orc(path1.toString) + spark.range(1, 2).toDF("b").coalesce(1).write.orc(path2.toString) + spark.range(2, 3).toDF("a").coalesce(1).write.json(path3.toString) + + val fileStatuses = + Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten + + val schema = SchemaMergeUtils.mergeSchemasInParallel( + spark, + fileStatuses, + schemaReader) + + assert(schema.isDefined) + assert(schema.get == StructType(Seq( + StructField("a", LongType, true), + StructField("b", LongType, true)))) + } + } + } + + protected def 
testMergeSchemasInParallel( + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]): Unit = { + testMergeSchemasInParallel(true, schemaReader) + val exception = intercept[SparkException] { + testMergeSchemasInParallel(false, schemaReader) + }.getCause + assert(exception.getCause.getMessage.contains("Could not read footer for file")) + } + test("create temporary orc table") { checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) @@ -332,6 +377,107 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { assert(version === SPARK_VERSION_SHORT) } } + + test("SPARK-11412 test orc merge schema option") { + val conf = spark.sessionState.conf + // Test if the default of spark.sql.orc.mergeSchema is false + assert(new OrcOptions(Map.empty[String, String], conf).mergeSchema == false) + + // OrcOptions's parameters have a higher priority than SQL configuration. + // `mergeSchema` -> `spark.sql.orc.mergeSchema` + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "true") { + val map1 = Map(OrcOptions.MERGE_SCHEMA -> "true") + val map2 = Map(OrcOptions.MERGE_SCHEMA -> "false") + assert(new OrcOptions(map1, conf).mergeSchema == true) + assert(new OrcOptions(map2, conf).mergeSchema == false) + } + + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "false") { + val map1 = Map(OrcOptions.MERGE_SCHEMA -> "true") + val map2 = Map(OrcOptions.MERGE_SCHEMA -> "false") + assert(new OrcOptions(map1, conf).mergeSchema == true) + assert(new OrcOptions(map2, conf).mergeSchema == false) + } + } + + test("SPARK-11412 test enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.orc(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.orc(new Path(basePath, "foo=2").toString) + assert(spark.read.orc(basePath).columns.length === expectedColumnNumber) + + // OrcOptions.MERGE_SCHEMA has higher priority + assert(spark.read.option(OrcOptions.MERGE_SCHEMA, true) + .orc(basePath).columns.length === 3) + assert(spark.read.option(OrcOptions.MERGE_SCHEMA, false) + .orc(basePath).columns.length === 2) + } + } + + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } + + test("SPARK-11412 test enabling/disabling schema merging with data type conflicts") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.orc(new Path(basePath, "foo=1").toString) + spark.range(0, 10).map(s => s"value_$s").toDF("a") + .write.orc(new Path(basePath, "foo=2").toString) + + // with schema merging, there should throw exception + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "true") { + val exception = intercept[SparkException] { + spark.read.orc(basePath).columns.length + }.getCause + + val innerMessage = orcImp match { + case "native" => exception.getMessage + case "hive" => exception.getCause.getMessage + case impl => + throw new UnsupportedOperationException(s"Unknown ORC implementation: $impl") + } + + assert(innerMessage.contains("Failed to merge incompatible data types")) + } + + // it is ok if no schema merging + withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "false") { + assert(spark.read.orc(basePath).columns.length === 2) + } + } + } + + test("SPARK-11412 test schema merging with corrupt files") { + 
withSQLConf(SQLConf.ORC_SCHEMA_MERGING_ENABLED.key -> "true") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.orc(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.orc(new Path(basePath, "foo=2").toString) + spark.range(0, 10).toDF("c").write.json(new Path(basePath, "foo=3").toString) + + // ignore corrupt files + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + assert(spark.read.orc(basePath).columns.length === 3) + } + + // don't ignore corrupt files + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + spark.read.orc(basePath).columns.length + }.getCause + assert(exception.getCause.getMessage.contains("Could not read footer for file")) + } + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { @@ -377,4 +523,8 @@ class OrcSourceSuite extends OrcSuite with SharedSQLContext { test("Enforce direct encoding column-wise selectively") { testSelectiveDictionaryEncoding(isSelective = true) } + + test("SPARK-11412 read and merge orc schemas in parallel") { + testMergeSchemasInParallel(OrcUtils.readOrcSchemasInParallel) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 9ac3e98f5f0b7..7f2eb14956dc1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -67,12 +67,20 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles - OrcFileOperator.readSchema( - files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()), - ignoreCorruptFiles - ) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + if (orcOptions.mergeSchema) { + SchemaMergeUtils.mergeSchemasInParallel( + sparkSession, + files, + OrcFileOperator.readOrcSchemasInParallel) + } else { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + OrcFileOperator.readSchema( + files.map(_.getPath.toString), + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles + ) + } } override def prepareWrite( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 713b70f252b6a..1a5f47bf5aa7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.orc import java.io.IOException import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector @@ -29,6 +29,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils private[hive] object OrcFileOperator extends Logging { /** @@ -101,6 +102,24 @@ private[hive] object OrcFileOperator 
extends Logging { } } + /** + * Reads ORC file schemas in multi-threaded manner, using Hive ORC library. + * This is visible for testing. + */ + def readOrcSchemasInParallel( + partFiles: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) + : Seq[StructType] = { + ThreadUtils.parmap(partFiles, "readingOrcSchemas", 8) { currentFile => + val file = currentFile.getPath.toString + getFileReader(file, Some(conf), ignoreCorruptFiles).map(reader => { + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $file., got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] + }) + }.flatten + } + def getObjectInspector( path: String, conf: Option[Configuration]): Option[StructObjectInspector] = { getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 6bcb2225e66d4..3104fb4d8173c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -166,4 +166,8 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { } } } + + test("SPARK-11412 read and merge orc schemas in parallel") { + testMergeSchemasInParallel(OrcFileOperator.readOrcSchemasInParallel) + } }
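To show how the refactored pieces compose, here is a minimal sketch that drives the shared merge path directly, in the spirit of OrcUtils.inferSchema and the testMergeSchemasInParallel helper above. It assumes an active SparkSession bound to `spark`, hypothetical input paths, and the use of sparkContext.hadoopConfiguration; note that SchemaMergeUtils and OrcUtils live in internal execution.datasources packages:

    import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}

    import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
    import org.apache.spark.sql.execution.datasources.orc.OrcUtils
    import org.apache.spark.sql.types.StructType

    // Collect ORC part files from two directories written with different schemas.
    val base = new Path("/tmp/orc_merge")
    val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
    val orcFiles: Seq[FileStatus] =
      Seq(new Path(base, "foo=1"), new Path(base, "foo=2"))
        .flatMap(p => fs.listStatus(p))
        .filter(_.getPath.getName.endsWith(".orc"))

    // One Spark job reads and merges the per-file schemas. The format-specific reader is
    // passed in as a plain function: here the native ORC reader, as in OrcUtils.inferSchema;
    // the Hive path plugs in OrcFileOperator.readOrcSchemasInParallel instead.
    val merged: Option[StructType] = SchemaMergeUtils.mergeSchemasInParallel(
      spark, orcFiles, OrcUtils.readOrcSchemasInParallel)

    merged.foreach(schema => println(schema.treeString))   // expect fields a and b

Passing the reader as a function is what lets the Parquet footer-based path and both ORC implementations share one job-submission and merge loop, which is the point of extracting SchemaMergeUtils out of ParquetFileFormat.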