fail fast and do not attempt to read very large files
mengxr committed Apr 29, 2019
1 parent 20a3ef7 commit 11ff2cc
Showing 2 changed files with 48 additions and 0 deletions.
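
As a rough usage sketch of the behavior this commit introduces (the conf key, format name, column name, and error wording are taken from the diff below; the input path and the 10-byte limit are made up for illustration), a read now fails with a SparkException before any bytes of an oversized file are read:

    // Hypothetical example: cap binary files at 10 bytes via the internal
    // testing conf added in this commit, then read a directory that holds a
    // larger file. "/tmp/blobs" is an assumed path, not from the commit.
    import org.apache.spark.SparkException

    spark.conf.set("spark.test.data.source.binaryFile.maxLength", 10)
    try {
      spark.read.format("binaryFile")
        .load("/tmp/blobs")
        .select("content")
        .first()
    } catch {
      case e: SparkException =>
        // Fails fast, with a message along the lines of
        // "The length of <path> is <len>, which exceeds the max length allowed: 10."
        println(e.getMessage)
    } finally {
      spark.conf.unset("spark.test.data.source.binaryFile.maxLength")
    }

The guard compares the FileStatus length against the configured maximum, so the failure happens before fs.open is called on the oversized file.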
BinaryFileFormat.scala
@@ -24,11 +24,13 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, GlobFilter, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -99,6 +101,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
    val binaryFileSourceOptions = new BinaryFileSourceOptions(options)
    val pathGlobPattern = binaryFileSourceOptions.pathGlobFilter
    val filterFuncs = filters.map(filter => createFilterFunction(filter))
    val maxLength = sparkSession.conf.get(TEST_BINARY_FILE_MAX_LENGTH)

    file: PartitionedFile => {
      val path = new Path(file.filePath)
@@ -115,6 +118,11 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
          case (MODIFICATION_TIME, i) =>
            writer.write(i, DateTimeUtils.fromMillis(status.getModificationTime))
          case (CONTENT, i) =>
            if (status.getLen > maxLength) {
              throw new SparkException(
                s"The length of ${status.getPath} is ${status.getLen}, " +
                  s"which exceeds the max length allowed: ${maxLength}.")
            }
            val stream = fs.open(status.getPath)
            try {
              writer.write(i, ByteStreams.toByteArray(stream))
@@ -143,6 +151,16 @@ object BinaryFileFormat {
  private[binaryfile] val CONTENT = "content"
  private[binaryfile] val BINARY_FILE = "binaryFile"

  private[binaryfile]
  val CONF_TEST_BINARY_FILE_MAX_LENGTH = "spark.test.data.source.binaryFile.maxLength"
  /** An internal conf for testing max length. */
  private[binaryfile] val TEST_BINARY_FILE_MAX_LENGTH = SQLConf
    .buildConf(CONF_TEST_BINARY_FILE_MAX_LENGTH)
    .internal()
    .intConf
    // The theoretical max length is Int.MaxValue, though VMs might implement a smaller max.
    .createWithDefault(Int.MaxValue)

  /**
   * Schema for the binary file data source.
   *
BinaryFileFormatSuite.scala
@@ -27,6 +27,7 @@ import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}

import org.apache.spark.SparkException
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
@@ -339,4 +340,33 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
assert(df.select("LENGTH").first().getLong(0) === content.length,
"column pruning should be case insensitive")
}

test("fail fast and do not attempt to read if a file is too big") {
assert(spark.conf.get(TEST_BINARY_FILE_MAX_LENGTH) === Int.MaxValue)
withTempPath { file =>
val path = file.getPath
val content = "123".getBytes
Files.write(file.toPath, content, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
def readContent(maxLength: Int): Array[Byte] = {
try {
spark.conf.set(CONF_TEST_BINARY_FILE_MAX_LENGTH, maxLength)
spark.read.format(BINARY_FILE)
.load(path)
.select(CONTENT)
.first()
.getAs[Array[Byte]](0)
} finally {
spark.conf.unset(CONF_TEST_BINARY_FILE_MAX_LENGTH)
}
}
assert(readContent(Int.MaxValue) === content)
assert(readContent(content.length) === content)
// Disable read. If the implementation attempts to read, the exception would be different.
file.setReadable(false)
val caught = intercept[SparkException] {
readContent(content.length - 1)
}
assert(caught.getMessage.contains("exceeds the max length allowed"))
}
}
}
