Commit

Addressing Hyukjin Kwon's review comments
MaxGekk committed Apr 22, 2018
1 parent fdeac84 commit 1b86df3
Showing 2 changed files with 18 additions and 16 deletions.
@@ -2128,16 +2128,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     }
   }
 
-  val sampledTestData = (value: java.lang.Long) => {
-    val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
-      57, 62, 68, 72)
-    if (predefinedSample.contains(value)) {
-      s"""{"f1":${value.toString}}"""
-    } else {
-      s"""{"f1":${(value.toDouble + 0.1).toString}}"""
-    }
-  }
-
   test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") {
     // Set default values for the DataSource parameters to make sure
     // that whole test file is mapped to only one partition. This will guarantee
@@ -2146,7 +2136,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString,
       "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString
     )(withTempPath { path =>
-      val ds = spark.range(0, 100, 1, 1).map(sampledTestData)
+      val ds = sampledTestData.coalesce(1)
       ds.write.text(path.getAbsolutePath)
       val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath)
 
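A note on the replaced line above: the old spark.range(0, 100, 1, 1) pinned the test data to a single partition via its fourth argument (numPartitions), and sampling is only stable when the partitioning is fixed, so the refactored helper (now a plain Dataset[String] defined in TestJsonData, below) is pinned back down with coalesce(1). A minimal check, as a sketch and not part of the diff:

    // coalesce(1) restores the single partition that
    // spark.range(0, 100, 1, 1) used to guarantee.
    assert(ds.rdd.getNumPartitions == 1)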
@@ -2155,7 +2145,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   }
 
   test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") {
-    val ds = spark.range(0, 100, 1, 1).map(sampledTestData)
+    val ds = sampledTestData.coalesce(1)
     val readback = spark.read.option("samplingRatio", 0.1).json(ds)
 
     assert(readback.schema == new StructType().add("f1", LongType))
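For contrast, a hedged sketch (not in the diff; fullScan is an illustrative name, and DoubleType comes from the same types package as LongType) of what inference over all rows would yield on this dataset: most rows carry fractional f1 values, and merging Long with Double widens to Double.

    // Default samplingRatio is 1.0, so every row is inspected
    // and f1 is inferred as DoubleType rather than LongType.
    val fullScan = spark.read.json(ds)
    assert(fullScan.schema == new StructType().add("f1", DoubleType))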
@@ -2180,11 +2170,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
   test("SPARK-23849: sampling files for schema inferring in the multiLine mode") {
     withTempDir { dir =>
-      Files.write(Paths.get(dir.getAbsolutePath, "0.json"), """{"a":"a"}""".getBytes,
-        StandardOpenOption.CREATE_NEW)
+      Files.write(Paths.get(dir.getAbsolutePath, "0.json"),
+        """{"a":"a"}""".getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE_NEW)
       for (i <- 1 until 10) {
-        Files.write(Paths.get(dir.getAbsolutePath, s"$i.json"), s"""{"a":$i}""".getBytes,
-          StandardOpenOption.CREATE_NEW)
+        Files.write(Paths.get(dir.getAbsolutePath, s"$i.json"),
+          s"""{"a":$i}""".getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE_NEW)
       }
       val files = (0 until 10).map { i =>
         val hadoopConf = spark.sessionState.newHadoopConf()
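The tail of this test is collapsed in the hunk. As a hedged sketch of the call pattern it exercises (the readback name and the exact option values are assumptions, though multiLine and samplingRatio are both real JSON datasource options): in multiLine mode each file holds a single record, so lowering samplingRatio amounts to sampling whole files during schema inference, which is what the test name describes.

      // Infer the schema from a sample of the ten files written above.
      val readback = spark.read
        .option("multiLine", true)
        .option("samplingRatio", 0.1)
        .json(dir.getAbsolutePath)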
@@ -233,4 +233,16 @@ private[json] trait TestJsonData {
     spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING)
 
   def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING)
+
+  def sampledTestData: Dataset[String] = {
+    spark.range(0, 100, 1).map { index =>
+      val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
+        57, 62, 68, 72)
+      if (predefinedSample.contains(index)) {
+        s"""{"f1":${index.toString}}"""
+      } else {
+        s"""{"f1":${(index.toDouble + 0.1).toString}}"""
+      }
+    }(Encoders.STRING)
+  }
 }
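To make the new helper concrete: an index that falls in predefinedSample serializes f1 as an integral JSON number, every other index gets a fractional one, which is what lets the tests above distinguish a sampled scan (LongType) from a full one (DoubleType). A sketch of the first rows it emits (indices 0 and 1 are not in the set, index 2 is):

  sampledTestData.show(3, truncate = false)
  // +----------+
  // |value     |
  // +----------+
  // |{"f1":0.1}|
  // |{"f1":1.1}|
  // |{"f1":2}  |
  // +----------+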
