1 change: 1 addition & 0 deletions .rat-excludes
@@ -93,3 +93,4 @@ INDEX
.lintr
gen-java.*
.*avpr
org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1,3 @@
org.apache.spark.sql.jdbc.DefaultSource
org.apache.spark.sql.json.DefaultSource
org.apache.spark.sql.parquet.DefaultSource
Contributor

Should ORC be added as well? I see a change to OrcRelation.scala below.

Contributor Author

ORC is added in the other resource file, since Hive is a separate package.

@@ -17,7 +17,12 @@

package org.apache.spark.sql.execution.datasources

import java.util.ServiceLoader

import scala.collection.Iterator
import scala.collection.JavaConversions._
import scala.language.{existentials, implicitConversions}
import scala.util.{Failure, Success, Try}
import scala.util.matching.Regex

import org.apache.hadoop.fs.Path
@@ -190,37 +195,32 @@ private[sql] class DDLParser(
}
}

private[sql] object ResolvedDataSource {

private val builtinSources = Map(
"jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
"json" -> "org.apache.spark.sql.json.DefaultSource",
"parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
"orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
)
private[sql] object ResolvedDataSource extends Logging {

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String): Class[_] = {
val provider2 = s"$provider.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader

if (builtinSources.contains(provider)) {
return loader.loadClass(builtinSources(provider))
}

try {
loader.loadClass(provider)
} catch {
case cnf: java.lang.ClassNotFoundException =>
try {
loader.loadClass(provider + ".DefaultSource")
} catch {
case cnf: java.lang.ClassNotFoundException =>
if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
sys.error("The ORC data source must be used with Hive support enabled.")
} else {
sys.error(s"Failed to load class for data source: $provider")
}
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match {
// The provider format did not match any registered alias.
case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
case Success(dataSource) => dataSource
case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
throw new ClassNotFoundException(
"The ORC data source must be used with Hive support enabled.", error)
} else {
throw new ClassNotFoundException(
s"Failed to load class for data source: $provider", error)
}
}
// There is exactly one registered alias.
case head :: Nil => head.getClass
// There are multiple registered aliases for the input.
case sources => sys.error(s"Multiple sources found for $provider, " +
Contributor

Is printing an error enough? If it is, this should probably be logError. But it feels like this should be an exception.

Contributor Author

sys.error generates a runtime exception, is that what you want?

s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
"please specify the fully qualified class name")
}
}
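(Regarding the sys.error exchange above: in the Scala 2.x standard library, sys.error never returns normally. From memory, its definition is essentially the following.)

// Paraphrased from the scala.sys package object:
// error always throws a RuntimeException, so the "multiple sources" case ends the lookup.
def error(message: String): Nothing = throw new RuntimeException(message)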

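For context, a minimal sketch of how the new lookup resolves a registered alias ("json" is one of the aliases registered in this patch):

// ServiceLoader instantiates every class listed in a
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister resource
// on the classpath and compares each format() to the provider name, case-insensitively.
val cls = ResolvedDataSource.lookupDataSource("json")
// cls == classOf[org.apache.spark.sql.json.DefaultSource]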
@@ -77,7 +77,10 @@ private[sql] object JDBCRelation {
}
}

private[sql] class DefaultSource extends RelationProvider {
private[sql] class DefaultSource extends RelationProvider with DataSourceRegister {

def format(): String = "jdbc"

/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
@@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}

private[sql] class DefaultSource extends HadoopFsRelationProvider {
private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {

def format(): String = "json"

override def createRelation(
sqlContext: SQLContext,
paths: Array[String],
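With the alias registered, the user-facing read path is unchanged; a sketch (the path is illustrative):

// "json" now resolves through the ServiceLoader rather than the old hard-coded map.
val df = sqlContext.read.format("json").load("/tmp/example.json")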
@@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}


private[sql] class DefaultSource extends HadoopFsRelationProvider {
private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {

def format(): String = "parquet"

override def createRelation(
sqlContext: SQLContext,
paths: Array[String],
@@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql._
import org.apache.spark.util.SerializableConfiguration

/**
 * ::DeveloperApi::
 * Data sources should implement this trait so that they can register an alias for their
 * data source. This allows users to specify the alias as the format type instead of the
 * fully qualified class name.
 *
 * ex: parquet.DefaultSource.format = "parquet".
 *
 * A new instance of this class will be instantiated each time a DDL call is made.
 */
@DeveloperApi
trait DataSourceRegister {

/**
* The string that represents the format this data source provider uses. This is
* overridden by subclasses to provide a short alias for the data source,
* ex: override def format(): String = "parquet"
*/
def format(): String
}
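
As an illustration of the extension point (hypothetical names, not part of this patch), a third-party source would mix in this trait and list its class in a services resource named after the trait:

// com/example/CsvSource.scala (hypothetical third-party data source)
class CsvSource extends RelationProvider with DataSourceRegister {

  // The alias users can pass to DataFrameReader.format(...)
  def format(): String = "csv"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = ???
}

plus a resource file META-INF/services/org.apache.spark.sql.sources.DataSourceRegister containing the single line com.example.CsvSource.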

/**
* ::DeveloperApi::
* Implemented by objects that produce relations for a specific kind of data source. When
@@ -0,0 +1,3 @@
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class FakeSourceOne extends RelationProvider with DataSourceRegister {

def format(): String = "Fluet da Bomb"

override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
new BaseRelation {
override def sqlContext: SQLContext = cont

override def schema: StructType =
StructType(Seq(StructField("stringType", StringType, nullable = false)))
}
}

class FakeSourceTwo extends RelationProvider with DataSourceRegister {

def format(): String = "Fluet da Bomb"

override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
new BaseRelation {
override def sqlContext: SQLContext = cont

override def schema: StructType =
StructType(Seq(StructField("stringType", StringType, nullable = false)))
}
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

def format(): String = "gathering quorum"

override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
new BaseRelation {
override def sqlContext: SQLContext = cont

override def schema: StructType =
StructType(Seq(StructField("stringType", StringType, nullable = false)))
}
}

// Please note that META-INF/services had to be modified in the test directory for this to work.
class DDLSourceLoadSuite extends DataSourceTest {

test("data sources with the same name") {
intercept[RuntimeException] {
caseInsensitiveContext.read.format("Fluet da Bomb").load()
}
}

test("load data source from format alias") {
caseInsensitiveContext.read.format("gathering quorum").load().schema ==
StructType(Seq(StructField("stringType", StringType, nullable = false)))
}

test("specify full classname with duplicate formats") {
caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
.load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
}
Contributor

Maybe we should add a test that loading ORC fails with a plain SQLContext, and another with HiveContext.


test("Loading Orc") {
intercept[ClassNotFoundException] {
caseInsensitiveContext.read.format("orc").load()
}
}
}
@@ -0,0 +1 @@
org.apache.spark.sql.hive.orc.DefaultSource
@@ -47,7 +47,10 @@ import org.apache.spark.util.SerializableConfiguration
/* Implicit conversions */
import scala.collection.JavaConversions._

private[sql] class DefaultSource extends HadoopFsRelationProvider {
private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {

def format(): String = "orc"

def createRelation(
sqlContext: SQLContext,
paths: Array[String],