From 837839ee95b6c3d9dc035c717348fcb23286b951 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 11 Jan 2017 18:19:56 +0000 Subject: [PATCH 1/4] DatasourceScanExec uses runtime sparksession --- .../sql/execution/DataSourceScanExec.scala | 75 ++++++++++++------- .../datasources/DataSourceStrategy.scala | 3 +- .../datasources/FileSourceStrategy.scala | 5 +- .../PruneFileSourcePartitions.scala | 3 +- .../parquet/ParquetQuerySuite.scala | 33 +++++++- .../sql/sources/HadoopFsRelationTest.scala | 29 ++++++- 6 files changed, 116 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 39b010efec7b0..5bd76a94704d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution +import java.util.concurrent.{Callable, TimeUnit} + import scala.collection.mutable.ArrayBuffer +import com.google.common.cache.{Cache, CacheBuilder} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} @@ -147,23 +150,26 @@ case class FileSourceScanExec( override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( - relation.sparkSession, StructType.fromAttributes(output)) + def supportsBatch: Boolean = relation.fileFormat.supportBatch( + sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + def needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { + sparkSession.sessionState.conf.parquetVectorizedReaderEnabled } else { false } + def sparkSession: SparkSession = SparkSession.getActiveSession.get + + private val readerCache: Cache[SparkSession, RDD[InternalRow]] = + CacheBuilder.newBuilder() + .expireAfterAccess(4, TimeUnit.HOURS) + .build() + @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) - override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { - val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { - relation.bucketSpec - } else { - None - } + private def partitioningAndOrder( + bucketSpec: Option[BucketSpec]): (Partitioning, Seq[SortOrder]) = { bucketSpec match { case Some(spec) => // For bucketed columns: @@ -225,6 +231,17 @@ case class FileSourceScanExec( } } + private def bucketSpec: Option[BucketSpec] = + if (sparkSession.sessionState.conf.bucketingEnabled) { + relation.bucketSpec + } else { + None + } + + override def outputPartitioning: Partitioning = partitioningAndOrder(bucketSpec)._1 + + override def outputOrdering: Seq[SortOrder] = partitioningAndOrder(bucketSpec)._2 + // These metadata values make scan plans uniquely identifiable for equality checking. 
override val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -248,25 +265,31 @@ case class FileSourceScanExec( withOptPartitionCount } - private lazy val inputRDD: RDD[InternalRow] = { - val readFile: (PartitionedFile) => Iterator[InternalRow] = - relation.fileFormat.buildReaderWithPartitionValues( - sparkSession = relation.sparkSession, - dataSchema = relation.dataSchema, - partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, - filters = dataFilters, - options = relation.options, - hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + private def inputRDDInternal(sparkSession: SparkSession): RDD[InternalRow] = { + val readFile = relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = outputSchema, + filters = dataFilters, + options = relation.options, + hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) relation.bucketSpec match { - case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled => + case Some(bucketing) if sparkSession.sessionState.conf.bucketingEnabled => createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation) case _ => createNonBucketedReadRDD(readFile, selectedPartitions, relation) } } + private def inputRDD: RDD[InternalRow] = { + val sparkSession = sparkSession + readerCache.get(sparkSession, new Callable[RDD[InternalRow]] { + override def call(): RDD[InternalRow] = inputRDDInternal(sparkSession) + }) + } + override def inputRDDs(): Seq[RDD[InternalRow]] = { inputRDD :: Nil } @@ -370,7 +393,7 @@ case class FileSourceScanExec( FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) } - new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + new FileScanRDD(sparkSession, readFile, filePartitions) } /** @@ -385,10 +408,10 @@ case class FileSourceScanExec( readFile: (PartitionedFile) => Iterator[InternalRow], selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val defaultMaxSplitBytes = - fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes - val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism + val session = sparkSession + val defaultMaxSplitBytes = session.sessionState.conf.filesMaxPartitionBytes + val openCostInBytes = session.sessionState.conf.filesOpenCostInBytes + val defaultParallelism = session.sparkContext.defaultParallelism val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum val bytesPerCore = totalBytes / defaultParallelism diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 19db293132f54..a4a7bb25eefaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -176,8 +176,9 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } + val sparkSession = SparkSession.getActiveSession.get val partitionSchema = actualQuery.resolve( - t.partitionSchema, 
t.sparkSession.sessionState.analyzer.resolver) + t.partitionSchema, sparkSession.sessionState.analyzer.resolver) val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get } InsertIntoHadoopFsRelationCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 26e1380eca499..5ec706be4009b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -71,16 +71,17 @@ object FileSourceStrategy extends Strategy with Logging { } } + val sparkSession = SparkSession.getActiveSession.get val partitionColumns = l.resolve( - fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) + fsRelation.partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = - l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) + l.resolve(fsRelation.dataSchema, sparkSession.sessionState.analyzer.resolver) // Partition keys are not available in the statistics of the files. val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8566a8061034b..29177701a18c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -47,7 +48,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { } } - val sparkSession = fsRelation.sparkSession + val sparkSession = SparkSession.getActiveSession.get val partitionColumns = logicalRelation.resolve( partitionSchema, sparkSession.sessionState.analyzer.resolver) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index d7d7176c48a3a..069284a1ca66b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.execution.datasources.{FileScanRDD, SQLHadoopMapReduceCommitProtocol} import 
org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -738,6 +738,37 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("DataSourceScanExec uses active spark session") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withTempPath { dir => + dir.getAbsoluteFile + val path = "file://" + dir.getCanonicalPath + spark.range(4).coalesce(1).write.parquet(path) + val df = spark.read.parquet(path) + + val Some((scan1, fileScanRDD1)) = df.queryExecution.executedPlan.collectFirst { + case scan: FileSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] => + (scan, scan.inputRDDs().head.asInstanceOf[FileScanRDD]) + } + + val supportsBatchInitially = scan1.supportsBatch + + val newSession = spark.newSession() + newSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + SparkSession.setActiveSession(newSession) + val Some((scan2, fileScanRDD2)) = df.queryExecution.executedPlan.collectFirst { + case scan: FileSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] => + (scan, scan.inputRDDs().head.asInstanceOf[FileScanRDD]) + } + + assert(scan1 == scan2) + assert(supportsBatchInitially) + assert(supportsBatchInitially != scan2.supportsBatch) + assert(fileScanRDD1 != fileScanRDD2) + } + } + } } object TestingUDT { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index d23b66a5300e7..3fbe728f77a8a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.DataSourceScanExec +import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -897,6 +897,33 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes checkAnswer(df, readBack) } } + + test("DataSourceScanExec uses active session upon execution") { + withTempPath { dir => + val path = "file://" + dir.getCanonicalPath + spark.range(4).coalesce(1).write.format(dataSourceName).save(path) + val df = spark.read.format(dataSourceName).load(path) + val Some(scan1) = df.queryExecution.executedPlan.collectFirst { + case scan: FileSourceScanExec => scan + } + + val newSession = spark.newSession() + + val session1 = scan1.sparkSession + SparkSession.setActiveSession(newSession) + val Some(scan2) = df.queryExecution.executedPlan.collectFirst { + case scan: FileSourceScanExec => scan + } + + val session2 = scan2.sparkSession + + assert(scan1 == scan2) + assert(session1 == spark) + assert(session2 == newSession) + + SparkSession.setActiveSession(spark) + } + } } // This class is used to test SPARK-8578. 
We should not use any custom output committer when From b4075333a6c24f64c4772975a87b3abf51b56742 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Sat, 14 Jan 2017 12:02:19 +0000 Subject: [PATCH 2/4] fix recursive variables --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 5bd76a94704d9..0458a18f643e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -284,7 +284,6 @@ case class FileSourceScanExec( } private def inputRDD: RDD[InternalRow] = { - val sparkSession = sparkSession readerCache.get(sparkSession, new Callable[RDD[InternalRow]] { override def call(): RDD[InternalRow] = inputRDDInternal(sparkSession) }) From 1e9fd71dbeebfdcf5956e9c6fb20abb536f2c7e2 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 26 Jan 2017 14:17:44 +0000 Subject: [PATCH 3/4] remove caching --- .../sql/execution/DataSourceScanExec.scala | 124 ++++++++---------- 1 file changed, 57 insertions(+), 67 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 0458a18f643e4..e4c748602734d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -17,27 +17,24 @@ package org.apache.spark.sql.execution -import java.util.concurrent.{Callable, TimeUnit} - import scala.collection.mutable.ArrayBuffer -import com.google.common.cache.{Cache, CacheBuilder} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils trait DataSourceScanExec extends LeafExecNode with CodegenSupport { @@ -161,87 +158,86 @@ case class FileSourceScanExec( def sparkSession: SparkSession = SparkSession.getActiveSession.get - private val readerCache: Cache[SparkSession, RDD[InternalRow]] = - CacheBuilder.newBuilder() - .expireAfterAccess(4, TimeUnit.HOURS) - .build() - @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) - private def partitioningAndOrder( - bucketSpec: Option[BucketSpec]): (Partitioning, Seq[SortOrder]) 
= { + private def bucketSpec: Option[BucketSpec] = + if (sparkSession.sessionState.conf.bucketingEnabled) { + relation.bucketSpec + } else { + None + } + + // For bucketed columns: + // ----------------------- + // `HashPartitioning` would be used only when: + // 1. ALL the bucketing columns are being read from the table + // + // For sorted columns: + // --------------------- + // Sort ordering should be used when ALL these criteria's match: + // 1. `HashPartitioning` is being used + // 2. A prefix (or all) of the sort columns are being read from the table. + // + // Sort ordering would be over the prefix subset of `sort columns` being read + // from the table. + // eg. + // Assume (col0, col2, col3) are the columns read from the table + // If sort columns are (col0, col1), then sort ordering would be considered as (col0) + // If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 + // above + override def outputPartitioning: Partitioning = { + bucketSpec match { + case Some(spec) => + def toAttribute(colName: String): Option[Attribute] = output.find(_.name == colName) + + val bucketColumns = spec.bucketColumnNames.flatMap(toAttribute) + if (bucketColumns.size == spec.bucketColumnNames.size) { + HashPartitioning(bucketColumns, spec.numBuckets) + } else { + UnknownPartitioning(0) + } + case _ => + UnknownPartitioning(0) + } + } + + override def outputOrdering: Seq[SortOrder] = { bucketSpec match { case Some(spec) => - // For bucketed columns: - // ----------------------- - // `HashPartitioning` would be used only when: - // 1. ALL the bucketing columns are being read from the table - // - // For sorted columns: - // --------------------- - // Sort ordering should be used when ALL these criteria's match: - // 1. `HashPartitioning` is being used - // 2. A prefix (or all) of the sort columns are being read from the table. - // - // Sort ordering would be over the prefix subset of `sort columns` being read - // from the table. - // eg. - // Assume (col0, col2, col3) are the columns read from the table - // If sort columns are (col0, col1), then sort ordering would be considered as (col0) - // If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 - // above - - def toAttribute(colName: String): Option[Attribute] = - output.find(_.name == colName) - - val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) + def toAttribute(colName: String): Option[Attribute] = output.find(_.name == colName) + + val bucketColumns = spec.bucketColumnNames.flatMap(toAttribute) if (bucketColumns.size == spec.bucketColumnNames.size) { - val partitioning = HashPartitioning(bucketColumns, spec.numBuckets) - val sortColumns = - spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) + val sortColumns = spec.sortColumnNames.map(toAttribute).takeWhile(_.isDefined).map(_.get) - val sortOrder = if (sortColumns.nonEmpty) { + if (sortColumns.nonEmpty) { // In case of bucketing, its possible to have multiple files belonging to the // same bucket in a given relation. Each of these files are locally sorted // but those files combined together are not globally sorted. 
Given that, // the RDD partition will not be sorted even if the relation has sort columns set // Current solution is to check if all the buckets have a single file in it - val files = selectedPartitions.flatMap(partition => partition.files) + val files = selectedPartitions.flatMap(_.files) val bucketToFilesGrouping = - files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file)) - val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1) - - if (singleFilePartitions) { + files.map(_.getPath.getName).groupBy(BucketingUtils.getBucketId) + if (bucketToFilesGrouping.forall(_._2.length <= 1)) { // TODO Currently Spark does not support writing columns sorting in descending order // so using Ascending order. This can be fixed in future - sortColumns.map(attribute => SortOrder(attribute, Ascending)) + sortColumns.map(SortOrder(_, Ascending)) } else { Nil } } else { Nil } - (partitioning, sortOrder) } else { - (UnknownPartitioning(0), Nil) + Nil } case _ => - (UnknownPartitioning(0), Nil) + Nil } } - private def bucketSpec: Option[BucketSpec] = - if (sparkSession.sessionState.conf.bucketingEnabled) { - relation.bucketSpec - } else { - None - } - - override def outputPartitioning: Partitioning = partitioningAndOrder(bucketSpec)._1 - - override def outputOrdering: Seq[SortOrder] = partitioningAndOrder(bucketSpec)._2 - // These metadata values make scan plans uniquely identifiable for equality checking. override val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -265,7 +261,7 @@ case class FileSourceScanExec( withOptPartitionCount } - private def inputRDDInternal(sparkSession: SparkSession): RDD[InternalRow] = { + private def inputRDD: RDD[InternalRow] = { val readFile = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = sparkSession, dataSchema = relation.dataSchema, @@ -283,12 +279,6 @@ case class FileSourceScanExec( } } - private def inputRDD: RDD[InternalRow] = { - readerCache.get(sparkSession, new Callable[RDD[InternalRow]] { - override def call(): RDD[InternalRow] = inputRDDInternal(sparkSession) - }) - } - override def inputRDDs(): Seq[RDD[InternalRow]] = { inputRDD :: Nil } From 72ee9b3d31502a1bf259a30e511174984b8444c5 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Mon, 30 Jan 2017 00:10:41 +0000 Subject: [PATCH 4/4] define dataSchema for simplehadoopfsrelationtests --- .../sql/sources/HadoopFsRelationTest.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 3fbe728f77a8a..b655e2bc5920d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.sources import java.io.File import scala.util.Random - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.hadoop.ParquetOutputCommitter - import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataFrame, _} import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec} import org.apache.spark.sql.execution.datasources._ import 
org.apache.spark.sql.hive.test.TestHiveSingleton @@ -900,10 +898,18 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("DataSourceScanExec uses active session upon execution") { withTempPath { dir => - val path = "file://" + dir.getCanonicalPath - spark.range(4).coalesce(1).write.format(dataSourceName).save(path) - val df = spark.read.format(dataSourceName).load(path) - val Some(scan1) = df.queryExecution.executedPlan.collectFirst { + val childDir = new File(dir, dataSourceName).getCanonicalPath + val dataDf = spark.range(4).coalesce(1).toDF() + dataDf.write.format(dataSourceName).save(childDir) + val reader = spark.read.format(dataSourceName) + + // This is needed for SimpleTextHadoopFsRelationSuite as SimpleTextSource needs schema. + if (dataSourceName == classOf[SimpleTextSource].getCanonicalName) { + reader.option("dataSchema", dataDf.schema.json) + } + + val readDf = reader.load(childDir) + val Some(scan1) = readDf.queryExecution.executedPlan.collectFirst { case scan: FileSourceScanExec => scan } @@ -911,7 +917,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val session1 = scan1.sparkSession SparkSession.setActiveSession(newSession) - val Some(scan2) = df.queryExecution.executedPlan.collectFirst { + val Some(scan2) = readDf.queryExecution.executedPlan.collectFirst { case scan: FileSourceScanExec => scan }
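
Note (illustrative, not part of the patch series): the four patches boil down to replacing state captured eagerly from relation.sparkSession with lookups of SparkSession.getActiveSession at call time, so conf-derived behaviour such as supportsBatch and needsUnsafeRowConversion follows the session that is active when the plan executes rather than the one that built it (the Guava reader cache introduced in patch 1 is removed again in patch 3, leaving inputRDD as a plain def). The sketch below shows the eager-val versus active-session-def pattern with plain Scala stand-ins; Session, ActiveSession, EagerScan, and RuntimeScan are toy names for this sketch only, not Spark classes.

import java.util.concurrent.atomic.AtomicReference

// Toy stand-in for SparkSession: a name plus a conf map.
final case class Session(name: String, conf: Map[String, String])

object ActiveSession {
  // Spark tracks the active session per thread; one holder is enough for this sketch.
  private val current = new AtomicReference[Session]()
  def set(s: Session): Unit = current.set(s)
  def get: Session = Option(current.get).getOrElse(sys.error("no active session"))
}

// Before the patch: conf-derived flags were vals, frozen when the plan node was built.
final class EagerScan(builderSession: Session) {
  val vectorized: Boolean = builderSession.conf.getOrElse("vectorized", "true").toBoolean
}

// After the patch: the node resolves the *active* session on each access, so the flag
// reflects whatever session is current when the plan is actually executed.
final class RuntimeScan {
  private def session: Session = ActiveSession.get
  def vectorized: Boolean = session.conf.getOrElse("vectorized", "true").toBoolean
}

object Demo extends App {
  val original = Session("original", Map("vectorized" -> "true"))
  ActiveSession.set(original)

  val eager   = new EagerScan(original)
  val runtime = new RuntimeScan

  // Switch to a session with a different conf, as the added tests do via
  // spark.newSession() followed by SparkSession.setActiveSession(newSession).
  ActiveSession.set(Session("new", Map("vectorized" -> "false")))

  println(eager.vectorized)   // true  -- still reflects the session used at construction
  println(runtime.vectorized) // false -- reflects the session active at execution time
}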