diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 14f8659f15b3f..2e0c6c51c00e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 
 import parquet.hadoop.ParquetInputFormat
 import parquet.hadoop.util.ContextUtil
@@ -31,8 +32,8 @@ import org.apache.spark.{Partition => SparkPartition, Logging}
 import org.apache.spark.rdd.{NewHadoopPartition, RDD}
 
 import org.apache.spark.sql.{SQLConf, Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, And, Expression, Attribute}
-import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.sources._
 
 import scala.collection.JavaConversions._
@@ -151,8 +152,6 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
   override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = {
     // This is mostly a hack so that we can use the existing parquet filter code.
     val requiredColumns = output.map(_.name)
-    // TODO: Parquet filters should be based on data sources API, not catalyst expressions.
-    val filters = DataSourceStrategy.selectFilters(predicates)
 
     val job = new Job(sparkContext.hadoopConfiguration)
     ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport])
@@ -160,35 +159,34 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
 
     val requestedSchema = StructType(requiredColumns.map(schema(_)))
 
-    // TODO: Make folder based partitioning a first class citizen of the Data Sources API.
-    val partitionFilters = filters.collect {
-      case e @ EqualTo(attr, value) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr=$value")
-        (p: Partition) => p.partitionValues(attr) == value
-
-      case e @ In(attr, values) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr IN ${values.mkString("{", ",", "}")}")
-        val set = values.toSet
-        (p: Partition) => set.contains(p.partitionValues(attr))
-
-      case e @ GreaterThan(attr, value) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr > $value")
-        (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] > value.asInstanceOf[Int]
-
-      case e @ GreaterThanOrEqual(attr, value) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr >= $value")
-        (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] >= value.asInstanceOf[Int]
+    val partitionKeySet = partitionKeys.toSet
+    val rawPredicate =
+      predicates
+        .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet))
+        .reduceOption(And)
+        .getOrElse(Literal(true))
+
+    // Translate the predicate so that it reads from the information derived from the
+    // folder structure
+    val castedPredicate = rawPredicate transform {
+      case a: AttributeReference =>
+        val idx = partitionKeys.indexWhere(a.name == _)
+        BoundReference(idx, IntegerType, nullable = true)
+    }
 
-      case e @ LessThan(attr, value) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr < $value")
-        (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] < value.asInstanceOf[Int]
+    val inputData = new GenericMutableRow(partitionKeys.size)
+    val pruningCondition = InterpretedPredicate(castedPredicate)
 
-      case e @ LessThanOrEqual(attr, value) if partitionKeys.contains(attr) =>
-        logInfo(s"Parquet scan partition filter: $attr <= $value")
-        (p: Partition) => p.partitionValues(attr).asInstanceOf[Int] <= value.asInstanceOf[Int]
-    }
+    val selectedPartitions =
+      if (partitionKeys.nonEmpty && predicates.nonEmpty) {
+        partitions.filter { part =>
+          inputData(0) = part.partitionValues.values.head
+          pruningCondition(inputData)
+        }
+      } else {
+        partitions
+      }
 
-    val selectedPartitions = partitions.filter(p => partitionFilters.forall(_(p)))
     val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration)
     val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath))
     // FileInputFormat cannot handle empty lists.
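
For reference, the pruning logic this patch introduces boils down to three steps: keep only the predicates whose references are all partition columns, combine them with And (falling back to a true literal when none apply), and evaluate the combined condition against each partition's folder-derived values. The following is a minimal, self-contained Scala sketch of that flow, not Spark code: Expr, Attr, Partition, and prunePartitions are hypothetical stand-ins for the catalyst Expression, AttributeReference/BoundReference, and InterpretedPredicate machinery the patch actually uses.

// Standalone sketch; every name here is an illustrative stand-in, not a Spark API.
object PartitionPruningSketch {

  // A tiny predicate language over named partition columns.
  sealed trait Expr { def references: Set[String] }
  case class Attr(name: String) extends Expr { def references = Set(name) }
  case class Lit(value: Any) extends Expr { def references = Set.empty[String] }
  case class GreaterThan(left: Expr, right: Expr) extends Expr {
    def references = left.references ++ right.references
  }
  case class And(left: Expr, right: Expr) extends Expr {
    def references = left.references ++ right.references
  }

  // Evaluate a predicate against one partition's folder-derived values,
  // playing the role InterpretedPredicate plays over a GenericMutableRow.
  def eval(e: Expr, row: Map[String, Any]): Any = e match {
    case Attr(n)           => row(n)
    case Lit(v)            => v
    case GreaterThan(l, r) => eval(l, row).asInstanceOf[Int] > eval(r, row).asInstanceOf[Int]
    case And(l, r)         => eval(l, row).asInstanceOf[Boolean] && eval(r, row).asInstanceOf[Boolean]
  }

  // Stand-in for ParquetRelation2's Partition: the values parsed from the path.
  case class Partition(partitionValues: Map[String, Any], path: String)

  def prunePartitions(
      partitionKeys: Set[String],
      predicates: Seq[Expr],
      partitions: Seq[Partition]): Seq[Partition] = {
    // Steps 1 and 2: keep predicates that touch only partition columns and AND them together;
    // if none qualify, every partition survives (Lit(true)).
    val pruningCondition = predicates
      .filter(_.references.subsetOf(partitionKeys))
      .reduceOption(And(_, _))
      .getOrElse(Lit(true))
    // Step 3: evaluate the combined condition once per partition.
    partitions.filter(p => eval(pruningCondition, p.partitionValues).asInstanceOf[Boolean])
  }

  def main(args: Array[String]): Unit = {
    val parts = (1 to 5).map(i => Partition(Map("part" -> i), s"/table/part=$i"))
    val kept = prunePartitions(Set("part"), Seq(GreaterThan(Attr("part"), Lit(3))), parts)
    println(kept.map(_.path)) // Vector(/table/part=4, /table/part=5)
  }
}

Unlike the illustration above, the patch binds attribute names to ordinals up front (the BoundReference rewrite) and reuses a single mutable row, which avoids per-partition map lookups when many partitions are scanned.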