Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
liancheng committed Mar 24, 2016
1 parent 75dc296 commit 057b6f2
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {

val bucketedRDD = new UnionRDD(t.sqlContext.sparkContext,
(0 until spec.numBuckets).map { bucketId =>
bucketedDataMap.get(bucketId).getOrElse {
t.sqlContext.emptyResult: RDD[InternalRow]
}
bucketedDataMap.getOrElse(bucketId, t.sqlContext.emptyResult: RDD[InternalRow])
})
bucketedRDD
}
Expand Down Expand Up @@ -387,7 +385,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
result.setColumn(resultIdx, input.column(inputIdx))
inputIdx += 1
} else {
require(partitionColumnSchema.fields.filter(_.name.equals(attr.name)).length == 1)
require(partitionColumnSchema.fields.count(_.name == attr.name) == 1)
var partitionIdx = 0
partitionColumnSchema.fields.foreach { f => {
if (f.name.equals(attr.name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ case class PartitionedFile(
filePath: String,
start: Long,
length: Long) {
override def toString(): String = {
// Human-readable description for debugging/logging: the file path, the byte
// range read from it ([start, start + length)), and the partition values the
// file is associated with.
override def toString: String = {
s"path: $filePath, range: $start-${start + length}, partition values: $partitionValues"
}
}
Expand All @@ -44,7 +44,7 @@ case class PartitionedFile(
*
* TODO: This currently does not take locality information about the files into account.
*/
case class FilePartition(val index: Int, files: Seq[PartitionedFile]) extends Partition
// One RDD partition of a file-based scan: `index` is the partition's position
// (required by `Partition`), and `files` lists the PartitionedFile splits this
// partition covers.
case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends Partition

class FileScanRDD(
@transient val sqlContext: SQLContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
Expand All @@ -56,9 +55,10 @@ import org.apache.spark.sql.types._
*/
private[sql] object FileSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l@LogicalRelation(files: HadoopFsRelation, _, _))
case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _))
if (files.fileFormat.toString == "TestFileFormat" ||
files.fileFormat.isInstanceOf[parquet.DefaultSource]) &&
files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
files.fileFormat.toString == "ORC") &&
files.sqlContext.conf.parquetFileScan =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
Expand All @@ -81,10 +81,10 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
val bucketColumns =
AttributeSet(
files.bucketSpec
.map(_.bucketColumnNames)
.getOrElse(Nil)
.map(l.resolveQuoted(_, files.sqlContext.conf.resolver)
.getOrElse(sys.error(""))))
.map(_.bucketColumnNames)
.getOrElse(Nil)
.map(l.resolveQuoted(_, files.sqlContext.conf.resolver)
.getOrElse(sys.error(""))))

// Partition keys are not available in the statistics of the files.
val dataFilters = filters.filter(_.references.intersect(partitionSet).isEmpty)
Expand All @@ -101,8 +101,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {

val readDataColumns =
dataColumns
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
val prunedDataSchema = readDataColumns.toStructType
logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}")

Expand All @@ -120,13 +120,12 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
case Some(bucketing) if files.sqlContext.conf.bucketingEnabled =>
logInfo(s"Planning with ${bucketing.numBuckets} buckets")
val bucketed =
selectedPartitions
.flatMap { p =>
p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen))
}.groupBy { f =>
selectedPartitions.flatMap { p =>
p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen))
}.groupBy { f =>
BucketingUtils
.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}

(0 until bucketing.numBuckets).map { bucketId =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,11 @@ private[sql] class DefaultSource
// Try to push down filters when filter push-down is enabled.
val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
// is used here.
.flatMap(ParquetFilters.createFilter(dataSchema, _))
.reduceOption(FilterApi.and)
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
// is used here.
.flatMap(ParquetFilters.createFilter(dataSchema, _))
.reduceOption(FilterApi.and)
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.execution.datasources.text

import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.HiveMetastoreTypes
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -92,7 +91,6 @@ private[orc] object OrcFileOperator extends Logging {
// TODO: Check if the paths coming in are already qualified and simplify.
val origPath = new Path(pathStr)
val fs = origPath.getFileSystem(conf)
val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
.filterNot(_.isDirectory)
.map(_.getPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc

import java.util.Properties

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
Expand All @@ -27,10 +29,10 @@ import org.apache.hadoop.hive.ql.io.orc.OrcFile.OrcTableProperties
import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector
import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils}
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.hadoop.mapred.{JobConf, RecordWriter, Reporter, InputFormat => MapRedInputFormat, OutputFormat => MapRedOutputFormat}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.hadoop.mapreduce._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{HadoopRDD, RDD}
Expand All @@ -44,7 +46,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

private[sql] class DefaultSource extends FileFormat with DataSourceRegister {
private[sql] class DefaultSource
extends FileFormat with DataSourceRegister with HiveInspectors with Serializable {

override def shortName(): String = "orc"

Expand All @@ -55,7 +58,9 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister {
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
OrcFileOperator.readSchema(
files.map(_.getPath.toUri.toString), Some(sqlContext.sparkContext.hadoopConfiguration))
files.map(_.getPath.toUri.toString),
Some(sqlContext.sparkContext.hadoopConfiguration)
)
}

override def prepareWrite(
Expand All @@ -80,8 +85,8 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister {
job.getConfiguration.set(
OrcTableProperties.COMPRESSION.getPropName,
OrcRelation
.shortOrcCompressionCodecNames
.getOrElse(codecName, CompressionKind.NONE).name())
.shortOrcCompressionCodecNames
.getOrElse(codecName, CompressionKind.NONE).name())
}

job.getConfiguration match {
Expand Down Expand Up @@ -117,6 +122,87 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister {
val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
OrcTableScan(sqlContext, output, filters, inputFiles).execute()
}

/**
 * Builds a function that maps a single [[PartitionedFile]] to an iterator over the
 * rows read from the corresponding ORC file.
 *
 * When ORC filter push-down is enabled (`sqlContext.conf.orcFilterPushDown`), convertible
 * filters are serialized into the Hadoop configuration as a Kryo-encoded search argument
 * before the configuration is broadcast to executors.
 *
 * @param sqlContext active SQLContext; its Hadoop configuration seeds the reader conf
 * @param partitionSchema schema of partition columns (not consulted in this method —
 *                        NOTE(review): partition values are presumably attached by the
 *                        caller; confirm)
 * @param dataSchema schema of the data columns to materialize from each file
 * @param filters data source filters, pushed down to ORC when enabled
 * @param options data source options (not consulted in this method)
 * @return a closure executed on executors that reads one file split at a time
 */
override def buildReader(
sqlContext: SQLContext,
partitionSchema: StructType,
dataSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
// Copy the session Hadoop conf so per-query reader settings don't leak back.
val orcConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)

if (sqlContext.conf.orcFilterPushDown) {
// Sets pushed predicates: encode the combined filter with Kryo and turn on
// Hive's index filtering so ORC can skip row groups.
OrcFilters.createFilter(filters.toArray).foreach { f =>
orcConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
orcConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
}
}

// Hadoop Configuration is not serializable; wrap it and broadcast once per query.
val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(orcConf))
// Temporary variables used to avoid serialization issue (capturing `dataSchema`
// directly would pull the enclosing instance into the closure).
val _dataSchema = dataSchema

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value

// Sets required columns: map each requested field to its physical index in the
// file's schema, and register the (sorted) id/name lists with Hive's reader.
// TODO De-duplicates this part and `OrcFileScan.addColumnIds`
val physicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).getOrElse {
sys.error("Failed to read schema from target ORC files.")
}
val ids = _dataSchema.map(a => physicalSchema.fieldIndex(a.name): Integer)
val (sortedIDs, sortedNames) = ids.zip(_dataSchema.fieldNames).sorted.unzip
HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)

// One RecordReader per input split of this file, created with a dummy task
// attempt id (we are already inside a Spark task, not a Hadoop MR task).
val recordReaders = {
val job = Job.getInstance(conf)
FileInputFormat.setInputPaths(job, file.filePath)

val inputFormat = new OrcNewInputFormat
val splits = inputFormat.getSplits(job)

splits.asScala.map { split =>
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
inputFormat.createRecordReader(split, hadoopAttemptContext)
}
}

// Lazily chain the per-split readers into one iterator of OrcStructs.
val orcStructIterator = recordReaders.iterator.flatMap {
new RecordReaderIterator[OrcStruct](_)
}

// Deserialize each OrcStruct into a reused mutable row, then project to UnsafeRow.
// TODO De-duplicates this part and `OrcFileScan.fillObject`
val deserializer = new OrcSerde
val maybeStructOI = OrcFileOperator.getObjectInspector(file.filePath, Some(conf))
val mutableRow = new SpecificMutableRow(_dataSchema.map(_.dataType))
val unsafeProjection = UnsafeProjection.create(_dataSchema)

// If no ObjectInspector is available (NOTE(review): presumably an empty or
// schema-less ORC file — confirm), produce no rows for this file.
maybeStructOI.map { oi =>
val (fieldRefs, fieldOrdinals) = _dataSchema.zipWithIndex.map {
case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal
}.unzip

// `unwrapperFor` (from HiveInspectors) converts Hive field values into
// Catalyst's internal representation, writing into the target ordinal.
val unwrappers = fieldRefs.map(unwrapperFor)

orcStructIterator.map { value =>
val raw = deserializer.deserialize(value)
var i = 0
while (i < fieldRefs.length) {
val fieldValue = oi.getStructFieldData(raw, fieldRefs(i))
if (fieldValue == null) {
mutableRow.setNullAt(fieldOrdinals(i))
} else {
unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
}
i += 1
}
unsafeProjection(mutableRow)
}
}.getOrElse(Iterator.empty)
}
}
}

private[orc] class OrcOutputWriter(
Expand Down Expand Up @@ -291,8 +377,8 @@ private[orc] case class OrcTableScan(
val orcFormat = new DefaultSource
val dataSchema =
orcFormat
.inferSchema(sqlContext, Map.empty, inputPaths)
.getOrElse(sys.error("Failed to read schema from target ORC files."))
.inferSchema(sqlContext, Map.empty, inputPaths)
.getOrElse(sys.error("Failed to read schema from target ORC files."))
// Sets requested columns
addColumnIds(dataSchema, attributes, conf)

Expand Down

0 comments on commit 057b6f2

Please sign in to comment.