From 8ca8bf48df72af3abab8715f319b822be8a50d6a Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Wed, 10 Feb 2016 15:07:56 -0800
Subject: [PATCH 1/2] Make FileStreamSource thread-safe and atomic

---
 .../streaming/FileStreamSource.scala          | 58 ++++++++++---------
 .../sql/execution/streaming/Source.scala      |  3 +
 .../sql/streaming/FileStreamSourceSuite.scala |  2 +-
 3 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 14ba9f69bb1d7..7cee0c729f60f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.io._
+import javax.annotation.concurrent.{GuardedBy, ThreadSafe}
 
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.HashMap
 import scala.io.Codec
 
 import com.google.common.base.Charsets.UTF_8
@@ -35,6 +36,7 @@ import org.apache.spark.util.collection.OpenHashSet
  *
  * TODO Clean up the metadata files periodically
  */
+@ThreadSafe
 class FileStreamSource(
     sqlContext: SQLContext,
     metadataPath: String,
@@ -44,10 +46,16 @@ class FileStreamSource(
     dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {
 
   private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+
+  // The following mutable fields must be updated only in `advanceToNextBatch` so that if any
+  // exception is thrown in `getNextBatch`, none of these fields is changed.
+  @GuardedBy("this")
   private var maxBatchId = -1
+  @GuardedBy("this")
   private val seenFiles = new OpenHashSet[String]
 
   /** Map of batch id to files. This map is also stored in `metadataPath`. */
+  @GuardedBy("this")
   private val batchToMetadata = new HashMap[Long, Seq[String]]
 
   {
@@ -93,29 +101,14 @@ class FileStreamSource(
 
   /**
    * Returns the maximum offset that can be retrieved from the source.
-   *
-   * `synchronized` on this method is for solving race conditions in tests. In the normal usage,
-   * there is no race here, so the cost of `synchronized` should be rare.
    */
-  private def fetchMaxOffset(): LongOffset = synchronized {
+  private def fetchNewFiles(): Seq[String] = {
     val filesPresent = fetchAllFiles()
-    val newFiles = new ArrayBuffer[String]()
-    filesPresent.foreach { file =>
-      if (!seenFiles.contains(file)) {
-        logDebug(s"new file: $file")
-        newFiles.append(file)
-        seenFiles.add(file)
-      } else {
-        logDebug(s"old file: $file")
-      }
-    }
-
+    val newFiles = filesPresent.filter(file => !seenFiles.contains(file))
     if (newFiles.nonEmpty) {
-      maxBatchId += 1
-      writeBatch(maxBatchId, newFiles)
+      writeBatch(maxBatchId + 1, newFiles)
     }
-
-    new LongOffset(maxBatchId)
+    newFiles
   }
 
   /**
    * Returns the next batch of data that is available after `start`, if any is available.
@@ -122,21 +115,33 @@ class FileStreamSource(
    */
-  override def getNextBatch(start: Option[Offset]): Option[Batch] = {
+  override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized {
     val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L)
-    val end = fetchMaxOffset()
-    val endId = end.offset
-
+    val newFiles = fetchNewFiles()
+    // Only increase the batch id when new files were found
+    val endId = if (newFiles.isEmpty) maxBatchId else maxBatchId + 1
     if (startId + 1 <= endId) {
       val files = (startId + 1 to endId).filter(_ >= 0).flatMap { batchId =>
         batchToMetadata.getOrElse(batchId, Nil)
-      }.toArray
+      }.toArray ++ newFiles
       logDebug(s"Return files from batches ${startId + 1}:$endId")
       logDebug(s"Streaming ${files.mkString(", ")}")
-      Some(new Batch(end, dataFrameBuilder(files)))
+      val batch = Some(new Batch(LongOffset(endId), dataFrameBuilder(files)))
+      if (newFiles.nonEmpty) {
+        // Update the state only right before returning, so that if any exception is thrown
+        // earlier in `getNextBatch`, no state has been changed.
+        advanceToNextBatch(maxBatchId + 1, newFiles)
+      }
+      batch
     } else {
       None
     }
   }
 
+  private def advanceToNextBatch(id: Int, files: Seq[String]): Unit = {
+    batchToMetadata(id) = files
+    files.foreach(seenFiles.add)
+    maxBatchId = id
+  }
+
   private def fetchAllBatchFiles(): Seq[FileStatus] = {
     try fs.listStatus(new Path(metadataPath)) catch {
       case _: java.io.FileNotFoundException =>
@@ -195,7 +200,6 @@ class FileStreamSource(
     } finally {
       writer.close()
     }
-    batchToMetadata(id) = files
   }
 
   /** Read the file names of the specified batch id from the metadata file */
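The pattern these changes adopt, a single coarse `synchronized` entry point plus one commit
method that is the only writer of the `@GuardedBy("this")` fields, can be distilled into a
small self-contained sketch. `BatchTracker` and everything in it are illustrative stand-ins
for this note, not Spark's API:

    import scala.collection.mutable

    // Minimal sketch of the pattern (illustrative names, not Spark's classes): all
    // reads and writes of the mutable fields happen inside `synchronized`, and the
    // fields are mutated only in `commitBatch`, after every fallible step succeeded.
    class BatchTracker {
      // All three fields are guarded by `this`.
      private var maxBatchId = -1L
      private val seenFiles = mutable.HashSet.empty[String]
      private val batchToFiles = mutable.HashMap.empty[Long, Seq[String]]

      /** Atomically turns unseen files into the next batch, if there are any. */
      def nextBatch(listing: Seq[String]): Option[(Long, Seq[String])] = synchronized {
        val newFiles = listing.filterNot(seenFiles.contains)
        if (newFiles.isEmpty) {
          None
        } else {
          val id = maxBatchId + 1
          // Fallible work (e.g. persisting a metadata file) belongs here, before
          // any field is touched, so a failure leaves the tracker unchanged.
          commitBatch(id, newFiles)
          Some((id, newFiles))
        }
      }

      // The single place where the guarded fields are updated.
      private def commitBatch(id: Long, files: Seq[String]): Unit = {
        batchToFiles(id) = files
        files.foreach(seenFiles.add)
        maxBatchId = id
      }
    }

Keeping all mutation in one commit method makes the failure-atomicity argument local: any
exception thrown before the commit leaves the state exactly as it was, which is the property
the comment on the guarded fields promises.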
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index 25922979ac83e..8735266113e2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import javax.annotation.concurrent.ThreadSafe
+
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -24,6 +26,7 @@ import org.apache.spark.sql.types.StructType
  * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark
  * will regularly query each [[Source]] to see if any more data is available.
  */
+@ThreadSafe
 trait Source {
 
   /** Returns the schema of the data from this source */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 7a4ee0ef264d8..7c51dcab36491 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.Utils
 
-class FileStreamSourceTest extends StreamTest with SharedSQLContext {
+abstract class FileStreamSourceTest extends StreamTest with SharedSQLContext {
 
   import testImplicits._
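Why the whole of `getNextBatch` must hold the lock, rather than only `fetchMaxOffset` as
before this patch: computing the offset and reading `batchToMetadata` form one check-then-act
sequence, and with per-method locking another thread could commit a new batch between the two
steps, so a reader could see a new `maxBatchId` paired with stale metadata. A small demo of
the coarse lock, reusing the hypothetical `BatchTracker` sketch above:

    object BatchTrackerDemo {
      def main(args: Array[String]): Unit = {
        val tracker = new BatchTracker
        // Two "queries" race to turn the same directory listing into a batch.
        val listing = Seq("file-1", "file-2", "file-3")
        val threads = (1 to 2).map { i =>
          new Thread(new Runnable {
            def run(): Unit = tracker.nextBatch(listing) match {
              case Some((id, files)) =>
                println(s"query $i committed batch $id: ${files.mkString(", ")}")
              case None =>
                println(s"query $i saw no new files")
            }
          })
        }
        threads.foreach(_.start())
        threads.foreach(_.join())
        // Exactly one thread commits batch 0; the other observes a consistent
        // "nothing new" state rather than a half-updated one.
      }
    }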
From 26e9ad1c77d1c785c4baba8e722b02795fc8d71e Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Wed, 10 Feb 2016 21:18:47 -0800
Subject: [PATCH 2/2] Remove ThreadSafe and add documentation instead

---
 .../spark/sql/execution/streaming/FileStreamSource.scala  | 7 ++++---
 .../org/apache/spark/sql/execution/streaming/Source.scala | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 7cee0c729f60f..ea5c5e4865fcc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.io._
-import javax.annotation.concurrent.{GuardedBy, ThreadSafe}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.HashMap
 import scala.io.Codec
@@ -32,11 +32,12 @@ import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.collection.OpenHashSet
 
 /**
- * A very simple source that reads text files from the given directory as they appear.
+ * A very simple source that reads text files from the given directory as they appear. This
+ * class is thread-safe: it can be called concurrently from multiple threads running different
+ * [[org.apache.spark.sql.ContinuousQuery ContinuousQuery]] instances.
  *
  * TODO Clean up the metadata files periodically
  */
-@ThreadSafe
 class FileStreamSource(
     sqlContext: SQLContext,
     metadataPath: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index 8735266113e2a..d49494ee76777 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -17,16 +17,16 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import javax.annotation.concurrent.ThreadSafe
-
 import org.apache.spark.sql.types.StructType
 
 /**
  * A source of continually arriving data for a streaming query. A [[Source]] must have a
  * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark
  * will regularly query each [[Source]] to see if any more data is available.
+ *
+ * Implementations must be thread-safe: a source may be called concurrently from multiple
+ * threads running different [[org.apache.spark.sql.ContinuousQuery ContinuousQuery]] instances.
  */
-@ThreadSafe
 trait Source {
 
   /** Returns the schema of the data from this source */
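The second patch drops the `javax.annotation.concurrent.ThreadSafe` annotation in favor of
scaladoc. The annotation is advisory metadata from the JSR-305 jar: the Scala compiler does
not check it and nothing enforces it at runtime, so a sentence in the contract documentation
carries the same information without the extra import. A sketch of what such a documented
contract and one straightforward way to honor it might look like; `SimpleSource` and
`InMemorySource` are simplified stand-ins, not Spark's real `Source` API:

    /**
     * A source of continually arriving records. Implementations must be
     * thread-safe: several queries may poll one instance from different threads.
     * (`Long` plays the role of the real API's Offset.)
     */
    trait SimpleSource {
      /** Returns the records after offset `start`, with the new end offset. */
      def getNextBatch(start: Option[Long]): Option[(Long, Seq[String])]
    }

    /** One straightforward way to honor the contract: a coarse per-instance lock. */
    class InMemorySource extends SimpleSource {
      private var records = Vector.empty[String] // guarded by `this`

      def add(record: String): Unit = synchronized { records :+= record }

      override def getNextBatch(start: Option[Long]): Option[(Long, Seq[String])] =
        synchronized {
          val from = start.map(_.toInt + 1).getOrElse(0)
          if (from >= records.length) None
          else Some(((records.length - 1).toLong, records.drop(from)))
        }
    }

A coarse per-instance lock is the simplest contract-satisfying choice; finer-grained schemes
are possible, but each call must still observe one consistent snapshot of the source's state.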