Adding unit test for filtering
AndreSchumacher committed May 16, 2014
1 parent 6d22666 commit a93a588
Showing 3 changed files with 78 additions and 37 deletions.
@@ -32,7 +32,7 @@ import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType

import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Attribute, Expression, Row}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
@@ -78,8 +78,13 @@ case class ParquetTableScan(
ParquetFilters.serializeFilterExpressions(columnPruningPred.get, conf)
}

sc.newAPIHadoopRDD(conf, classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], classOf[Void], classOf[Row])
.map(_._2)
sc.newAPIHadoopRDD(
conf,
classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat],
classOf[Void],
classOf[Row])
.map(_._2)
.filter(_ != null) // Parquet's record filters may produce null values
}

override def otherCopyArgs = sc :: Nil
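
The extra .filter(_ != null) in the hunk above is needed because Parquet's record-level filters can leave null placeholders behind: the reader may still emit one entry per input record, with null standing in for records the filter rejected. A minimal, dependency-free sketch of that pattern in plain Scala (the keep predicate is a stand-in for the record filter, not the real Parquet API):

    // Plain-Scala illustration of why a post-scan null filter is required.
    object NullFilterSketch {
      final case class Record(myint: Int)

      def main(args: Array[String]): Unit = {
        val records = (0 until 10).map(Record(_))
        val keep: Record => Boolean = _.myint % 2 == 0
        // A filtering reader that yields null for rejected records instead of dropping them.
        val scanned: Seq[Record] = records.map(r => if (keep(r)) r else null)
        val rows = scanned.filter(_ != null)   // same post-filter as in the scan above
        println(rows.size)                     // 5
      }
    }
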
@@ -270,12 +275,17 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)

// We extend ParquetInputFormat in order to have more control over which
// RecordFilter we want to use
private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] {
override def createRecordReader(inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
private[parquet] class FilteringParquetRowInputFormat
extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
override def createRecordReader(
inputSplit: InputSplit,
taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
val readSupport: ReadSupport[Row] = new RowReadSupport()

val filterExpressions = ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
val filterExpressions =
ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
if (filterExpressions.isDefined) {
logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}")
new ParquetRecordReader[Row](readSupport, ParquetFilters.createFilter(filterExpressions.get))
} else {
new ParquetRecordReader[Row](readSupport)
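
FilteringParquetRowInputFormat never sees the ParquetTableScan operator directly; the pushed-down predicate has to make the round trip through the job's Hadoop Configuration, which is what the serializeFilterExpressions / deserializeFilterExpressions pair above is for. A rough sketch of that handshake using a single string property (the key and encoding here are hypothetical; the real ParquetFilters serialization is more involved):

    import org.apache.hadoop.conf.Configuration

    // Illustration only: a hypothetical property key and a plain-string encoding.
    object FilterConfRoundTrip {
      val FilterKey = "example.parquet.filter.predicate"

      def serialize(conf: Configuration, predicate: String): Unit =
        conf.set(FilterKey, predicate)

      def deserialize(conf: Configuration): Option[String] =
        Option(conf.get(FilterKey))

      def main(args: Array[String]): Unit = {
        val conf = new Configuration()
        serialize(conf, "myint < 150")
        println(deserialize(conf))   // Some(myint < 150)
      }
    }
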
@@ -34,10 +34,10 @@ import parquet.io.api.RecordConsumer
import parquet.hadoop.api.WriteSupport.WriteContext
import parquet.example.data.simple.SimpleGroup

// Write support class for nested groups:
// ParquetWriter initializes GroupWriteSupport with an empty configuration
// (it is after all not intended to be used in this way?)
// and members are private so we need to make our own
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
// with an empty configuration (it is after all not intended to be used in this way?)
// and members are private so we need to make our own in order to pass the schema
// to the writer.
private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
var groupWriter: GroupWriter = null
override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
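
The hunk cuts the class off after prepareForWrite, but the point of the comment above it is that a WriteSupport has to hand Parquet the schema from its own init method, because the ParquetWriter constructor used in these tests passes the write support an empty Configuration. A sketch of the shape such a write support takes (pre-Apache parquet-mr package names assumed; not necessarily the exact body of TestGroupWriteSupport):

    import java.util.{HashMap => JHashMap}

    import org.apache.hadoop.conf.Configuration
    import parquet.example.data.Group
    import parquet.hadoop.api.WriteSupport
    import parquet.hadoop.api.WriteSupport.WriteContext
    import parquet.io.api.RecordConsumer
    import parquet.schema.MessageType

    // Sketch only: the essential part is init() building the WriteContext from the
    // schema held by the class instead of reading it from the (empty) Configuration.
    class SchemaCarryingWriteSupport(schema: MessageType) extends WriteSupport[Group] {
      private var consumer: RecordConsumer = _

      override def init(configuration: Configuration): WriteContext =
        new WriteContext(schema, new JHashMap[String, String]())

      override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
        consumer = recordConsumer
      }

      override def write(record: Group): Unit = {
        // The real test helper delegates here to parquet's example GroupWriter,
        // which walks the Group against the schema and feeds the record consumer.
      }
    }
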
@@ -81,58 +81,74 @@ private[sql] object ParquetTestData {
|}
""".stripMargin

val testFilterSchema =
"""
|message myrecord {
|required boolean myboolean;
|required int32 myint;
|required binary mystring;
|required int64 mylong;
|required float myfloat;
|required double mydouble;
|}
""".stripMargin

// field names for test assertion error messages
val subTestSchemaFieldNames = Seq(
"myboolean:Boolean",
"mylong:Long"
)

val testDir = Utils.createTempDir()
val testFilterDir = Utils.createTempDir()

lazy val testData = new ParquetRelation(testDir.toURI.toString)

def writeFile() = {
testDir.delete
val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
val job = new Job()
val configuration: Configuration = ContextUtil.getConfiguration(job)
val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)

//val writeSupport = new MutableRowWriteSupport()
//writeSupport.setSchema(schema, configuration)
//val writer = new ParquetWriter(path, writeSupport)
val writeSupport = new TestGroupWriteSupport(schema)
//val writer = //new ParquetWriter[Group](path, writeSupport)
val writer = new ParquetWriter[Group](path, writeSupport)

for(i <- 0 until 15) {
val record = new SimpleGroup(schema)
//val data = new Array[Any](6)
if (i % 3 == 0) {
//data.update(0, true)
record.add(0, true)
} else {
//data.update(0, false)
record.add(0, false)
}
if (i % 5 == 0) {
record.add(1, 5)
// data.update(1, 5)
} else {
if (i % 5 == 1) record.add(1, 4)
}
//else {
// data.update(1, null) // optional
//}
//data.update(2, "abc")
record.add(2, "abc")
//data.update(3, i.toLong << 33)
record.add(3, i.toLong << 33)
//data.update(4, 2.5F)
record.add(4, 2.5F)
//data.update(5, 4.5D)
record.add(5, 4.5D)
//writer.write(new GenericRow(data.toArray))
writer.write(record)
}
writer.close()
}

def writeFilterFile() = {
testFilterDir.delete
val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet"))
val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema)
val writeSupport = new TestGroupWriteSupport(schema)
val writer = new ParquetWriter[Group](path, writeSupport)

for(i <- 0 to 200) {
val record = new SimpleGroup(schema)
if (i % 4 == 0) {
record.add(0, true)
} else {
record.add(0, false)
}
record.add(1, i)
record.add(2, i.toString)
record.add(3, i.toLong)
record.add(4, i.toFloat + 0.5f)
record.add(5, i.toDouble + 0.5d)
writer.write(record)
}
writer.close()
@@ -17,25 +17,22 @@

package org.apache.spark.sql.parquet

import java.io.File

import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.mapreduce.Job

import parquet.hadoop.ParquetFileWriter
import parquet.schema.MessageTypeParser
import parquet.hadoop.util.ContextUtil
import parquet.schema.MessageTypeParser

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util.getTempFilePath
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.TestData
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
import org.apache.spark.sql.{parquet, SchemaRDD}

// Implicits
import org.apache.spark.sql.test.TestSQLContext._
@@ -64,12 +61,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {

override def beforeAll() {
ParquetTestData.writeFile()
ParquetTestData.writeFilterFile()
testRDD = parquetFile(ParquetTestData.testDir.toString)
testRDD.registerAsTable("testsource")
parquetFile(ParquetTestData.testFilterDir.toString)
.registerAsTable("testfiltersource")
}

override def afterAll() {
Utils.deleteRecursively(ParquetTestData.testDir)
Utils.deleteRecursively(ParquetTestData.testFilterDir)
// here we should also unregister the table??
}

@@ -256,5 +257,19 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
assert(result != null)
}*/
}

test("test filter by predicate pushdown") {
for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
println(s"testing field $myval")
val result1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100").collect()
assert(result1.size === 50)
val result2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200").collect()
assert(result2.size === 50)
}
val booleanResult = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40").collect()
assert(booleanResult.size === 10)
val stringResult = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"").collect()
assert(stringResult.size === 1)
}
}
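
The asserted sizes in the new test follow directly from writeFilterFile above: it writes 201 records with myint = mylong = i for i from 0 to 200, myboolean = (i % 4 == 0) and mystring = i.toString, while myfloat and mydouble are i + 0.5, which shifts the boundary rows but still leaves 50 matches per range. A quick cross-check of the expected counts:

    // Recomputes the expected result sizes from the writeFilterFile data layout.
    object FilterTestCounts {
      def main(args: Array[String]): Unit = {
        val is = 0 to 200
        println(is.count(i => i >= 100 && i < 150))    // 50: myval < 150 AND myval >= 100
        println(is.count(i => i > 150 && i <= 200))    // 50: myval > 150 AND myval <= 200
        println(is.count(i => i % 4 == 0 && i < 40))   // 10: myboolean = true AND myint < 40
        println(is.count(i => i.toString == "100"))    //  1: mystring = "100"
      }
    }
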
