From 65b44fa3219538265055285cee072267af0a7b85 Mon Sep 17 00:00:00 2001 From: Luca Rosellini Date: Mon, 1 Jun 2015 14:22:38 +0200 Subject: [PATCH] [SPARK-4782] Added generic inferSchema method to allow schema infer for an RDD of a generic type T for which the user provides a mapping function from RDD[T] => RDD[Map[String,Any]] --- .../scala/org/apache/spark/sql/json/JsonRDD.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 95eb1174b1dd6..372d7bd8b2dc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -44,15 +44,24 @@ private[sql] object JsonRDD extends Logging { private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0, + samplingRatio: Double, columnNameOfCorruptRecords: String): StructType = { + + inferSchema(json, samplingRatio, columnNameOfCorruptRecords, parseJson) + } + + private[sql] def inferSchema[T <: AnyRef]( + json: RDD[T], + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String, + parseData: (RDD[T], String) => RDD[Map[String, Any]]): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = if (schemaData.isEmpty()) { Set.empty[(String, DataType)] } else { - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) + parseData(json,columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) } createSchema(allKeys) }