[SPARK-5775] [SQL] BugFix: GenericRow cannot be cast to SpecificMutableRow when nested data and partitioned table

This PR adapts anselmevignon's #4697 to master and branch-1.3. Please refer to the PR description of #4697 for details.
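As the diff below shows, both `ParquetTableScan` and `ParquetRelation2` used to cast every materialized row to `SpecificMutableRow` before filling in partition-column values, but the converter only returns a `SpecificMutableRow` (via `CatalystPrimitiveRowConverter`) when all requested columns are primitive; for nested schemas `CatalystGroupConverter` returns an immutable `GenericRow`, hence the `ClassCastException`. The following sketch shows the shape of the fix using illustrative stand-in types and names (`SketchRow`, `ImmutableSketchRow`, `MutableSketchRow`, `fillPartitionValues` are not Spark's actual classes): detect the primitive-only case up front and, otherwise, copy each row into a reusable mutable buffer before writing the partition values.

```scala
// Stand-in row types for illustration only; these are NOT Spark's classes.
trait SketchRow {
  def size: Int
  def apply(i: Int): Any
}

// Immutable row, analogous to the GenericRow returned by CatalystGroupConverter.
class ImmutableSketchRow(values: Array[Any]) extends SketchRow {
  def size: Int = values.length
  def apply(i: Int): Any = values(i)
}

// Mutable row, analogous to SpecificMutableRow / GenericMutableRow.
class MutableSketchRow(values: Array[Any]) extends SketchRow {
  def this(size: Int) = this(new Array[Any](size))
  def size: Int = values.length
  def apply(i: Int): Any = values(i)
  def update(i: Int, value: Any): Unit = values(i) = value
}

object PartitionFillSketch {
  // Fills partition-column values into each scanned row.
  //   primitiveRow      - true when every requested column is primitive, i.e. the
  //                       converter already hands back mutable rows
  //   partitionOrdinals - (index into partitionValues, target column index) pairs
  def fillPartitionValues(
      rows: Iterator[SketchRow],
      outputSize: Int,
      primitiveRow: Boolean,
      partitionOrdinals: Seq[(Int, Int)],
      partitionValues: Seq[Any]): Iterator[SketchRow] = {
    if (primitiveRow) {
      // The rows are already mutable, so fill the partition columns in place.
      rows.map { r =>
        val row = r.asInstanceOf[MutableSketchRow]
        partitionOrdinals.foreach { case (valueIdx, colIdx) =>
          row(colIdx) = partitionValues(valueIdx)
        }
        row
      }
    } else {
      // The rows are immutable: copy each one into a reusable mutable buffer first,
      // then fill the partition columns. The unconditional cast to a mutable row in
      // this case is what triggered the original ClassCastException.
      val mutableRow = new MutableSketchRow(outputSize)
      rows.map { row =>
        var i = 0
        while (i < row.size) {
          mutableRow(i) = row(i)
          i += 1
        }
        partitionOrdinals.foreach { case (valueIdx, colIdx) =>
          mutableRow(colIdx) = partitionValues(valueIdx)
        }
        mutableRow
      }
    }
  }
}
```

In the actual patch the primitive-only check is `output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))` in `ParquetTableScan`, and the equivalent check over `requestedSchema` in `ParquetRelation2`, as shown in the diff below.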

Author: Cheng Lian <lian@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #4792 from liancheng/spark-5775 and squashes the following commits:

538f506 [Cheng Lian] Addresses comments
cee55cf [Cheng Lian] Merge pull request #4 from yhuai/spark-5775-yin
b0b74fb [Yin Huai] Remove runtime pattern matching.
ca6e038 [Cheng Lian] Fixes SPARK-5775
liancheng committed Feb 28, 2015
1 parent 9168259 commit e6003f0
Showing 3 changed files with 217 additions and 24 deletions.
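For reference, the regression exercised by the new test tables below can be reproduced with a query like the following. This is a hypothetical repro sketch against a Spark 1.3-style `HiveContext` named `sqlContext`, with the partitioned external table already registered; it is not code from this PR.

```scala
// Query a partitioned Parquet table whose data schema contains a nested struct.
// Before this fix, the scan cast every materialized row to SpecificMutableRow even
// though the nested schema makes the converter return a GenericRow, so the query
// failed with "GenericRow cannot be cast to SpecificMutableRow".
val result = sqlContext.sql(
  """SELECT p, structField.intStructField, structField.stringStructField
    |FROM partitioned_parquet_with_complextypes
    |WHERE p = 1""".stripMargin)
result.collect().foreach(println)
```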
@@ -126,6 +126,13 @@ private[sql] case class ParquetTableScan(
conf)

if (requestedPartitionOrdinals.nonEmpty) {
// This check is based on CatalystConverter.createRootConverter.
val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))

// Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into
// the `mapPartitionsWithInputSplit` closure below.
val outputSize = output.size

baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
@@ -143,19 +150,47 @@ private[sql] case class ParquetTableScan(
relation.partitioningAttributes
.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))

new Iterator[Row] {
def hasNext = iter.hasNext
def next() = {
val row = iter.next()._2.asInstanceOf[SpecificMutableRow]

// Parquet will leave partitioning columns empty, so we fill them in here.
var i = 0
while (i < requestedPartitionOrdinals.size) {
row(requestedPartitionOrdinals(i)._2) =
partitionRowValues(requestedPartitionOrdinals(i)._1)
i += 1
if (primitiveRow) {
new Iterator[Row] {
def hasNext = iter.hasNext
def next() = {
// We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
val row = iter.next()._2.asInstanceOf[SpecificMutableRow]

// Parquet will leave partitioning columns empty, so we fill them in here.
var i = 0
while (i < requestedPartitionOrdinals.size) {
row(requestedPartitionOrdinals(i)._2) =
partitionRowValues(requestedPartitionOrdinals(i)._1)
i += 1
}
row
}
}
} else {
// Create a mutable row since we need to fill in values from partition columns.
val mutableRow = new GenericMutableRow(outputSize)
new Iterator[Row] {
def hasNext = iter.hasNext
def next() = {
// We are using CatalystGroupConverter and it returns a GenericRow.
// Since GenericRow is not mutable, we just cast it to a Row.
val row = iter.next()._2.asInstanceOf[Row]

var i = 0
while (i < row.size) {
mutableRow(i) = row(i)
i += 1
}
// Parquet will leave partitioning columns empty, so we fill them in here.
i = 0
while (i < requestedPartitionOrdinals.size) {
mutableRow(requestedPartitionOrdinals(i)._2) =
partitionRowValues(requestedPartitionOrdinals(i)._1)
i += 1
}
mutableRow
}
row
}
}
}
@@ -482,23 +482,53 @@ private[sql] case class ParquetRelation2(
// When the data does not include the key and the key is requested then we must fill it in
// based on information from the input split.
if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) {
// This check is based on CatalystConverter.createRootConverter.
val primitiveRow =
requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))

baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
case p if split.getPath.getParent.toString == p.path => p.values
}.get

val requiredPartOrdinal = partitionKeyLocations.keys.toSeq

iterator.map { pair =>
val row = pair._2.asInstanceOf[SpecificMutableRow]
var i = 0
while (i < requiredPartOrdinal.size) {
// TODO Avoids boxing cost here!
val partOrdinal = requiredPartOrdinal(i)
row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
i += 1
if (primitiveRow) {
iterator.map { pair =>
// We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
val row = pair._2.asInstanceOf[SpecificMutableRow]
var i = 0
while (i < requiredPartOrdinal.size) {
// TODO Avoids boxing cost here!
val partOrdinal = requiredPartOrdinal(i)
row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
i += 1
}
row
}
} else {
// Create a mutable row since we need to fill in values from partition columns.
val mutableRow = new GenericMutableRow(requestedSchema.size)
iterator.map { pair =>
// We are using CatalystGroupConverter and it returns a GenericRow.
// Since GenericRow is not mutable, we just cast it to a Row.
val row = pair._2.asInstanceOf[Row]
var i = 0
while (i < row.size) {
// TODO Avoids boxing cost here!
mutableRow(i) = row(i)
i += 1
}

i = 0
while (i < requiredPartOrdinal.size) {
// TODO Avoids boxing cost here!
val partOrdinal = requiredPartOrdinal(i)
mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
i += 1
}
mutableRow
}
row
}
}
} else {
@@ -36,6 +36,20 @@ case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(p: Int, intField: Int, stringField: String)

case class StructContainer(intStructField: Int, stringStructField: String)

case class ParquetDataWithComplexTypes(
intField: Int,
stringField: String,
structField: StructContainer,
arrayField: Seq[Int])

case class ParquetDataWithKeyAndComplexTypes(
p: Int,
intField: Int,
stringField: String,
structField: StructContainer,
arrayField: Seq[Int])

/**
* A suite to test the automatic conversion of metastore tables with parquet data to use the
@@ -86,6 +100,38 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
location '${new File(normalTableDir, "normal").getCanonicalPath}'
""")

sql(s"""
CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes
(
intField INT,
stringField STRING,
structField STRUCT<intStructField: INT, stringStructField: STRING>,
arrayField ARRAY<INT>
)
PARTITIONED BY (p int)
ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
STORED AS
INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
""")

sql(s"""
CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes
(
intField INT,
stringField STRING,
structField STRUCT<intStructField: INT, stringStructField: STRING>,
arrayField ARRAY<INT>
)
PARTITIONED BY (p int)
ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
STORED AS
INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
""")

(1 to 10).foreach { p =>
sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)")
}
@@ -94,7 +140,15 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)")
}

val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
(1 to 10).foreach { p =>
sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)")
}

(1 to 10).foreach { p =>
sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)")
}

val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
jsonRDD(rdd1).registerTempTable("jt")
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}"""))
jsonRDD(rdd2).registerTempTable("jt_array")
@@ -105,6 +159,8 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
override def afterAll(): Unit = {
sql("DROP TABLE partitioned_parquet")
sql("DROP TABLE partitioned_parquet_with_key")
sql("DROP TABLE partitioned_parquet_with_complextypes")
sql("DROP TABLE partitioned_parquet_with_key_and_complextypes")
sql("DROP TABLE normal_parquet")
sql("DROP TABLE IF EXISTS jt")
sql("DROP TABLE IF EXISTS jt_array")
@@ -409,6 +465,22 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest {
path '${new File(partitionedTableDir, "p=1").getCanonicalPath}'
)
""")

sql(s"""
CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
)
""")

sql(s"""
CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
)
""")
}

test("SPARK-6016 make sure to use the latest footers") {
@@ -473,7 +545,8 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var partitionedTableDir: File = null
var normalTableDir: File = null
var partitionedTableDirWithKey: File = null

var partitionedTableDirWithComplexTypes: File = null
var partitionedTableDirWithKeyAndComplexTypes: File = null

override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
@@ -509,9 +582,45 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
.toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}

partitionedTableDirWithKeyAndComplexTypes = File.createTempFile("parquettests", "sparksql")
partitionedTableDirWithKeyAndComplexTypes.delete()
partitionedTableDirWithKeyAndComplexTypes.mkdir()

(1 to 10).foreach { p =>
val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p")
sparkContext.makeRDD(1 to 10).map { i =>
ParquetDataWithKeyAndComplexTypes(
p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
}.toDF().saveAsParquetFile(partDir.getCanonicalPath)
}

partitionedTableDirWithComplexTypes = File.createTempFile("parquettests", "sparksql")
partitionedTableDirWithComplexTypes.delete()
partitionedTableDirWithComplexTypes.mkdir()

(1 to 10).foreach { p =>
val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p")
sparkContext.makeRDD(1 to 10).map { i =>
ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
}.toDF().saveAsParquetFile(partDir.getCanonicalPath)
}
}

override protected def afterAll(): Unit = {
partitionedTableDir.delete()
normalTableDir.delete()
partitionedTableDirWithKey.delete()
partitionedTableDirWithComplexTypes.delete()
partitionedTableDirWithKeyAndComplexTypes.delete()
}

Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
Seq(
"partitioned_parquet",
"partitioned_parquet_with_key",
"partitioned_parquet_with_complextypes",
"partitioned_parquet_with_key_and_complextypes").foreach { table =>

test(s"ordering of the partitioning columns $table") {
checkAnswer(
sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
@@ -601,6 +710,25 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
}
}

Seq(
"partitioned_parquet_with_key_and_complextypes",
"partitioned_parquet_with_complextypes").foreach { table =>

test(s"SPARK-5775 read struct from $table") {
checkAnswer(
sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"),
(1 to 10).map(i => Row(1, i, f"${i}_string")))
}

// Re-enable this after SPARK-5508 is fixed
ignore(s"SPARK-5775 read array from $table") {
checkAnswer(
sql(s"SELECT arrayField, p FROM $table WHERE p = 1"),
(1 to 10).map(i => Row(1 to i, 1)))
}
}


test("non-part select(*)") {
checkAnswer(
sql("SELECT COUNT(*) FROM normal_parquet"),
