Add DataSource.sourceSchema
zsxwing committed Apr 11, 2016
1 parent d161f3a commit 61fe406
Showing 8 changed files with 132 additions and 46 deletions.
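
At a glance, the change splits streaming source setup into two phases: `DataSource.sourceSchema()` returns only a `(name, schema)` pair at planning time, while `createSource(sourceId: Long, checkpointLocation: String)` builds the actual `Source` once the query starts and its checkpoint layout is known. The following is a rough, self-contained sketch of that flow; the object name, the example paths, and the literal schema are illustrative stand-ins, not the real Spark internals.

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object SourceSchemaFlowSketch {
  // Phase 1 (planning): only a name and a schema are needed, so no Source
  // (and no metadata log) has to be instantiated yet.
  def sourceSchema(): (String, StructType) =
    ("FileSource[/tmp/input]", StructType(StructField("a", IntegerType) :: Nil))

  // Phase 2 (query start): ContinuousQueryManager now knows the source id and
  // checkpoint location, so each source gets a stable metadata directory.
  def metadataPath(sourceId: Long, checkpointLocation: String): String =
    s"$checkpointLocation/sources/$sourceId"

  def main(args: Array[String]): Unit = {
    val (name, schema) = sourceSchema()
    println(s"planned $name with schema ${schema.simpleString}")
    println(s"source 0 keeps its log at ${metadataPath(0L, "/tmp/checkpoint")}")
  }
}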
@@ -182,7 +182,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
val logicalPlan = df.logicalPlan.transform {
case StreamingRelation(dataSource, _, output) =>
// Materialize source to avoid creating it in every batch
val source = dataSource.createSource(Some(nextSourceId), Some(checkpointLocation))
val source = dataSource.createSource(nextSourceId, checkpointLocation)
nextSourceId += 1
// We still need to use the previous `output` instead of `source.schema` because
// `df.logicalPlan` has already been resolved against the attributes of the previous `output`.
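
Why keeping the previous `output` matters: `StructType.toAttributes` (used by `StreamingRelation` further down) assigns fresh expression IDs on every call, and the analyzer resolves columns by expression ID rather than by name alone. A minimal sketch of that behavior, assuming only Spark SQL on the classpath:

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object AttributeIdentitySketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(StructField("a", IntegerType) :: Nil)
    // Same name and data type, but each call mints new expression IDs, so
    // attributes regenerated from source.schema would not match the ones
    // already referenced by df.logicalPlan.
    val first = schema.toAttributes
    val second = schema.toAttributes
    println(first.head.exprId == second.head.exprId) // prints false
  }
}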
@@ -123,28 +123,61 @@ case class DataSource(
}
}

def sourceSchema(): (String, StructType) = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.sourceSchema(sqlContext, userSpecifiedSchema, className, options)

case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})

val allPaths = caseInsensitiveOptions.get("path")
val globbedPaths = allPaths.toSeq.flatMap { path =>
val hdfsPath = new Path(path)
val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray

val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None)
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
sqlContext,
caseInsensitiveOptions,
fileCatalog.allFiles())
}.getOrElse {
throw new AnalysisException("Unable to infer schema. It must be specified manually.")
}

(s"FileSource[$path]", dataSchema)
case _ =>
throw new UnsupportedOperationException(
s"Data source $className does not support streamed reading")
}
}

/**
* Returns a source that can be used to continually read data.
*
* Before running a real query (e.g., df.explain), `sourceId` and `checkpointLocation` are None
* as they are unknown. [[ContinuousQueryManager]] should set `sourceId` and `checkpointLocation`
* before starting a query.
*/
def createSource(
sourceId: Option[Long] = None,
checkpointLocation: Option[String] = None): Source = {
def createSource(sourceId: Long, checkpointLocation: String): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.createSource(sqlContext, userSpecifiedSchema, className, options)
s.createSource(sqlContext, sourceId, userSpecifiedSchema, className, options)

case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})

val metadataPath =
sourceId.flatMap(id => checkpointLocation.map(location => s"$location/sources/$id"))
val metadataPath = s"$checkpointLocation/sources/$sourceId"

val allPaths = caseInsensitiveOptions.get("path")
val globbedPaths = allPaths.toSeq.flatMap { path =>
@@ -177,10 +210,8 @@ case class DataSource(
new CaseInsensitiveMap(options.filterKeys(_ != "path"))).resolveRelation()))
}

val source = new FileStreamSource(
sqlContext, path, Some(dataSchema), className, dataFrameBuilder)
metadataPath.foreach(source.setMetadataPath)
source
new FileStreamSource(
sqlContext, metadataPath, path, Some(dataSchema), className, dataFrameBuilder)
case _ =>
throw new UnsupportedOperationException(
s"Data source $className does not support streamed reading")
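
The FileFormat branches above resolve the user-supplied `path` option through Hadoop's filesystem API before globbing it. A small standalone sketch of that qualification step, using only the Hadoop client API (the `Configuration` and the example path are assumptions; Spark's `SparkHadoopUtil.globPathIfNecessary` helper is left out):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object PathQualificationSketch {
  // Resolve a possibly relative, scheme-less path against the filesystem's
  // URI and working directory, mirroring makeQualified(...) in the diff above.
  def qualify(raw: String, conf: Configuration): Path = {
    val hdfsPath = new Path(raw)
    val fs = hdfsPath.getFileSystem(conf)
    hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
  }

  def main(args: Array[String]): Unit =
    println(qualify("data/input/*.json", new Configuration()))
}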
@@ -33,15 +33,20 @@ import org.apache.spark.util.collection.OpenHashSet
*/
class FileStreamSource(
sqlContext: SQLContext,
metadataPath: String,
path: String,
dataSchema: Option[StructType],
providerName: String,
dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {

private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
private var metadataLog: MetadataLog[Seq[String]] = _
private var maxBatchId: Long = _
private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath)
private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)

private val seenFiles = new OpenHashSet[String]
metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) =>
files.foreach(seenFiles.add)
}

/** Returns the schema of the data from this source */
override lazy val schema: StructType = {
@@ -61,34 +66,13 @@ class FileStreamSource(
}
}

/**
* Set the metadata path. This method should be called before using [[FileStreamSource]].
*/
def setMetadataPath(metadataPath: String): Unit = {
if (metadataLog != null) {
throw new IllegalStateException("metadataPath has already been set")
}
metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath)
maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) =>
files.foreach(seenFiles.add)
}
}

private def assertMetadataPath(): Unit = {
if (metadataLog == null) {
throw new IllegalStateException("metadataPath has not been set yet")
}
}

/**
* Returns the maximum offset that can be retrieved from the source.
*
* `synchronized` on this method is for solving race conditions in tests. In normal usage
* there is no race here, so the cost of `synchronized` should be negligible.
*/
private def fetchMaxOffset(): LongOffset = synchronized {
assertMetadataPath()
val filesPresent = fetchAllFiles()
val newFiles = new ArrayBuffer[String]()
filesPresent.foreach { file =>
@@ -119,15 +103,13 @@ class FileStreamSource(

/** Return the latest offset in the source */
def currentOffset: LongOffset = synchronized {
assertMetadataPath()
new LongOffset(maxBatchId)
}

/**
* Returns the next batch of data that is available after `start`, if any is available.
*/
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
assertMetadataPath()
val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L)
val endId = end.asInstanceOf[LongOffset].offset

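
With the constructor change above, `FileStreamSource` recovers its state eagerly: it opens an `HDFSMetadataLog` at the metadata path, takes the latest batch id as `maxBatchId`, and replays every logged batch into `seenFiles`, instead of waiting for a later `setMetadataPath` call. A toy sketch of that recovery pattern with an in-memory stand-in for the metadata log (not Spark's `HDFSMetadataLog`):

import scala.collection.mutable

// In-memory stand-in mapping batch ids to the files committed in that batch.
class ToyMetadataLog {
  private var batches = scala.collection.immutable.SortedMap.empty[Long, Seq[String]]
  def add(batchId: Long, files: Seq[String]): Unit = batches += batchId -> files
  def getLatest(): Option[(Long, Seq[String])] = batches.lastOption
  def get(start: Option[Long], end: Option[Long]): Seq[(Long, Seq[String])] =
    batches.toSeq.filter { case (id, _) => start.forall(id >= _) && end.forall(id <= _) }
}

object FileSourceRecoverySketch {
  def main(args: Array[String]): Unit = {
    val log = new ToyMetadataLog
    log.add(0L, Seq("part-0.json", "part-1.json"))
    log.add(1L, Seq("part-2.json"))

    // Same shape as the new constructor logic: latest batch id first, then replay.
    val maxBatchId = log.getLatest().map(_._1).getOrElse(-1L)
    val seenFiles = mutable.HashSet.empty[String]
    log.get(None, Some(maxBatchId)).foreach { case (_, files) => seenFiles ++= files }

    println(s"maxBatchId=$maxBatchId, seenFiles=${seenFiles.toSeq.sorted.mkString(", ")}")
  }
}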
@@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.datasources.DataSource

object StreamingRelation {
def apply(dataSource: DataSource): StreamingRelation = {
val source = dataSource.createSource()
StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
val (name, schema) = dataSource.sourceSchema()
StreamingRelation(dataSource, name, schema.toAttributes)
}
}

@@ -129,8 +129,16 @@ trait SchemaRelationProvider {
* Implemented by objects that can produce a streaming [[Source]] for a specific format or system.
*/
trait StreamSourceProvider {

def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType)

def createSource(
sqlContext: SQLContext,
sourceId: Long,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source
@@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit

import scala.concurrent.duration._

import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql._
@@ -31,22 +32,53 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils

object LastOptions {

var mockStreamSourceProvider = mock(classOf[StreamSourceProvider])
var mockStreamSinkProvider = mock(classOf[StreamSinkProvider])
var sourceId: Long = -1
var parameters: Map[String, String] = null
var schema: Option[StructType] = null
var partitionColumns: Seq[String] = Nil

def clear(): Unit = {
sourceId = -1
parameters = null
schema = null
partitionColumns = null
reset(mockStreamSourceProvider)
reset(mockStreamSinkProvider)
}
}

/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */
class DefaultSource extends StreamSourceProvider with StreamSinkProvider {

private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
LastOptions.parameters = parameters
LastOptions.schema = schema
LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters)
("dummySource", fakeSchema)
}

override def createSource(
sqlContext: SQLContext,
sourceId: Long,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
LastOptions.sourceId = sourceId
LastOptions.parameters = parameters
LastOptions.schema = schema
LastOptions.mockStreamSourceProvider.createSource(
sqlContext, sourceId, schema, providerName, parameters)
new Source {
override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
override def schema: StructType = fakeSchema

override def getOffset: Option[Offset] = Some(new LongOffset(0))

@@ -64,6 +96,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
partitionColumns: Seq[String]): Sink = {
LastOptions.parameters = parameters
LastOptions.partitionColumns = partitionColumns
LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns)
new Sink {
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
}
@@ -117,7 +150,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")

LastOptions.parameters = null
LastOptions.clear()

df.write
.format("org.apache.spark.sql.streaming.test")
@@ -181,7 +214,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B

assert(LastOptions.parameters("path") == "/test")

LastOptions.parameters = null
LastOptions.clear()

df.write
.format("org.apache.spark.sql.streaming.test")
@@ -204,7 +237,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(LastOptions.parameters("boolOpt") == "false")
assert(LastOptions.parameters("doubleOpt") == "6.7")

LastOptions.parameters = null
LastOptions.clear()
df.write
.format("org.apache.spark.sql.streaming.test")
.option("intOpt", 56)
@@ -303,4 +336,28 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B

assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
}

test("sourceId") {
LastOptions.clear()

val df1 = sqlContext.read
.format("org.apache.spark.sql.streaming.test")
.stream()

val df2 = sqlContext.read
.format("org.apache.spark.sql.streaming.test")
.stream()

val q = df1.union(df2).write
.format("org.apache.spark.sql.streaming.test")
.option("checkpointLocation", newMetadataDir)
.trigger(ProcessingTime(10.seconds))
.startStream()
q.stop()

verify(LastOptions.mockStreamSourceProvider)
.createSource(sqlContext, 0L, None, "org.apache.spark.sql.streaming.test", Map.empty)
verify(LastOptions.mockStreamSourceProvider)
.createSource(sqlContext, 1L, None, "org.apache.spark.sql.streaming.test", Map.empty)
}
}
@@ -77,7 +77,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
.queryExecution.analyzed
.collect { case StreamingRelation(dataSource, _, _) =>
// There is only one source in our tests so just set sourceId to 0
dataSource.createSource(Some(0), Some(checkpointLocation)).asInstanceOf[FileStreamSource]
dataSource.createSource(0, checkpointLocation).asInstanceOf[FileStreamSource]
}.head
}

@@ -104,9 +104,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext
}
df.queryExecution.analyzed
.collect { case StreamingRelation(dataSource, _, _) =>
dataSource.createSource().asInstanceOf[FileStreamSource]
}.head
.schema
dataSource.sourceSchema()
}.head._2
}

test("FileStreamSource schema: no path") {
@@ -115,8 +115,17 @@ class StreamSuite extends StreamTest with SharedSQLContext {
*/
class FakeDefaultSource extends StreamSourceProvider {

private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema)

override def createSource(
sqlContext: SQLContext,
sourceId: Long,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {