Skip to content

Commit

Permalink
[SPARK-24924][SQL] Add mapping for built-in Avro data source
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjoon-hyun committed Jul 26, 2018
1 parent 0c83f71 commit d95ba40
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
Expand Up @@ -33,6 +33,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils

import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._

Expand All @@ -51,6 +52,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(newEntries, originalEntries)
}

test("resolve avro data source") {
  // Both the short name "avro" and the legacy databricks package name must
  // resolve to the built-in Avro file format implementation.
  val providers = Seq("avro", "com.databricks.spark.avro")
  for (provider <- providers) {
    val resolved = DataSource.lookupDataSource(provider, spark.sessionState.conf)
    assert(resolved === classOf[org.apache.spark.sql.avro.AvroFileFormat])
  }
}

test("reading from multiple paths") {
val df = spark.read.format("avro").load(episodesAvro, episodesAvro)
assert(df.count == 16)
Expand Down Expand Up @@ -456,7 +464,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
// get the same values back.
withTempPath { tempDir =>
val name = "AvroTest"
val namespace = "com.databricks.spark.avro"
val namespace = "org.apache.spark.avro"
val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)

val avroDir = tempDir + "/namedAvro"
Expand Down
Expand Up @@ -571,6 +571,7 @@ object DataSource extends Logging {
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
val socket = classOf[TextSocketSourceProvider].getCanonicalName
val rate = classOf[RateStreamProvider].getCanonicalName
val avro = "org.apache.spark.sql.avro.AvroFileFormat"

Map(
"org.apache.spark.sql.jdbc" -> jdbc,
Expand All @@ -592,6 +593,7 @@ object DataSource extends Logging {
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
"com.databricks.spark.csv" -> csv,
"com.databricks.spark.avro" -> avro,
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
"org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
)
Expand Down Expand Up @@ -635,12 +637,6 @@ object DataSource extends Logging {
"Hive built-in ORC data source must be used with Hive support enabled. " +
"Please use the native ORC data source by setting 'spark.sql.orc.impl' to " +
"'native'")
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
provider1 == "com.databricks.spark.avro") {
throw new AnalysisException(
s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " +
"Please find an Avro package at " +
"http://spark.apache.org/third-party-projects.html")
} else {
throw new ClassNotFoundException(
s"Failed to find data source: $provider1. Please find packages at " +
Expand Down

0 comments on commit d95ba40

Please sign in to comment.