##### Option 1. Recursive. Cars data

In [0]:
%scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{StructType,ArrayType}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.explode_outer

In [0]:
%scala
// val conf = new SparkConf().setMaster("local[*]").setAppName("JSON Flattener")
// val sc = new SparkContext(conf)

// val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val inputJson = """|{
                    | "name":"John",
                    | "age":30,
                    | "bike":{
                    |	"name":"Bajaj", "models":["Dominor", "Pulsar"]
                    |	},
                    | "cars": [
                    |   { "name":"Ford", "models":[ "Fiesta", "Focus", "Mustang" ] },
                    |   { "name":"BMW", "models":[ "320", "X3", "X5" ] },
                    |   { "name":"Fiat", "models":[ "500", "Panda" ] }
                    | ]
                    |}""".stripMargin
println(inputJson)

// Wrap the raw JSON text in a single-element RDD.
val jsonRDD = sc.parallelize(List(inputJson))

// Build the DataFrame from a one-element Dataset[String].
// (RDD-based alternative: sqlContext.read.json(jsonRDD))
val jsonDF = spark.read.json(spark.createDataset(List(inputJson)))

// Schema of the JSON DataFrame before flattening.
jsonDF.schema

// Rows of the DataFrame before flattening, shown with truncate = false.
jsonDF.show(truncate = false)
  
  //Function for exploding Array and StructType column
  
  /**
   * Recursively flattens the array and struct columns of `df` into plain top-level columns.
   *
   * The first column (in schema order) whose type is an array or a struct is rewritten, and the
   * function recurses on the result until no such column is left:
   *   - an array column is replaced by `explode_outer` of itself (one row per element), keeping
   *     its name;
   *   - a struct column is replaced by one column per child, named `<parent>_<child>`; any dot in
   *     a selected column name is replaced by an underscore at this step.
   * In both cases the rewritten column(s) are moved to the end of the column list.
   *
   * Note: exploding several independent array columns multiplies the rows, because every
   * explode is applied to the rows produced by the previous one.
   *
   * @param df DataFrame whose schema may contain array and struct columns
   * @return a DataFrame without top-level array or struct columns
   */
  def flattenDataframe(df: DataFrame): DataFrame = {
    // Back-tick quote a column name so that dots, spaces and other special characters in it are
    // taken literally instead of being parsed as a nested-field path or an SQL expression.
    def quoted(name: String): String = "`" + name.replace("`", "``") + "`"

    val fields = df.schema.fields

    // Only arrays and structs need rewriting; every other type is already flat.
    val firstNested = fields.find { field =>
      field.dataType match {
        case _: ArrayType | _: StructType => true
        case _ => false
      }
    }

    firstNested match {
      case None => df
      case Some(nested) =>
        val otherNames = fields.map(_.name).filter(_ != nested.name)
        nested.dataType match {
          case structType: StructType =>
            val keptColumns =
              otherNames.map(name => col(quoted(name)).as(name.replace(".", "_")))
            val childColumns = structType.fieldNames.map { child =>
              col(s"${quoted(nested.name)}.${quoted(child)}")
                .as(s"${nested.name}.$child".replace(".", "_"))
            }
            flattenDataframe(df.select(keptColumns ++ childColumns: _*))
          case _ =>
            // Only an ArrayType can reach this branch: `firstNested` accepts arrays and structs.
            val keptColumns = otherNames.map(name => col(quoted(name)))
            val exploded = explode_outer(col(quoted(nested.name))).as(nested.name)
            flattenDataframe(df.select(keptColumns :+ exploded: _*))
        }
    }
  }

In [0]:
%scala
val flattendedJSON = flattenDataframe(jsonDF)

// Schema of the JSON after flattening.
flattendedJSON.schema

// Rows of the DataFrame after flattening, shown with truncate = false.
flattendedJSON.show(truncate = false)

##### Option 1. Recursive. Donuts data

In [0]:
%scala
// Sample nested JSON document describing one donut: scalar fields, a `batters` object that
// holds a `batter` array of objects, and a `topping` array of objects.
val donut_json ="""

{
    "id": "0001",
    "type": "donut",
    "name": "Cake",
    "ppu": 0.55,
    "batters":
        {
            "batter":
                [
                    { "id": "1001", "type": "Regular" },
                    { "id": "1002", "type": "Chocolate" },
                    { "id": "1003", "type": "Blueberry" },
                    { "id": "1004", "type": "Devil's Food" }
                ]
        },
    "topping":
        [
            { "id": "5001", "type": "None" },
            { "id": "5002", "type": "Glazed" },
            { "id": "5005", "type": "Sugar" },
            { "id": "5007", "type": "Powdered Sugar" },
            { "id": "5006", "type": "Chocolate with Sprinkles" },
            { "id": "5003", "type": "Chocolate" },
            { "id": "5004", "type": "Maple" }
        ]
}
"""

In [0]:
%scala
 import sqlContext.implicits._
  
  
  // Wrap the raw donut JSON text in a single-element RDD.
  val jsonRDD = sc.parallelize(List(donut_json))

  // Build the DataFrame from a one-element Dataset[String].
  // (RDD-based alternative: sqlContext.read.json(jsonRDD))
  val jsonDF = spark.read.json(spark.createDataset(List(donut_json)))

  // Schema of the JSON DataFrame before flattening.
  jsonDF.schema

  // Rows of the DataFrame before flattening, shown with truncate = false.
  jsonDF.show(truncate = false)
  
  //Function for exploding Array and StructType column
  
  /**
   * Recursively flattens the array and struct columns of `df` into plain top-level columns.
   *
   * The first column (in schema order) whose type is an array or a struct is rewritten, and the
   * function recurses on the result until no such column is left:
   *   - an array column is replaced by `explode_outer` of itself (one row per element), keeping
   *     its name;
   *   - a struct column is replaced by one column per child, named `<parent>_<child>`; any dot in
   *     a selected column name is replaced by an underscore at this step.
   * In both cases the rewritten column(s) are moved to the end of the column list.
   *
   * Note: exploding several independent array columns multiplies the rows, because every
   * explode is applied to the rows produced by the previous one.
   *
   * @param df DataFrame whose schema may contain array and struct columns
   * @return a DataFrame without top-level array or struct columns
   */
  def flattenDataframe(df: DataFrame): DataFrame = {
    // Back-tick quote a column name so that dots, spaces and other special characters in it are
    // taken literally instead of being parsed as a nested-field path or an SQL expression.
    def quoted(name: String): String = "`" + name.replace("`", "``") + "`"

    val fields = df.schema.fields

    // Only arrays and structs need rewriting; every other type is already flat.
    val firstNested = fields.find { field =>
      field.dataType match {
        case _: ArrayType | _: StructType => true
        case _ => false
      }
    }

    firstNested match {
      case None => df
      case Some(nested) =>
        val otherNames = fields.map(_.name).filter(_ != nested.name)
        nested.dataType match {
          case structType: StructType =>
            val keptColumns =
              otherNames.map(name => col(quoted(name)).as(name.replace(".", "_")))
            val childColumns = structType.fieldNames.map { child =>
              col(s"${quoted(nested.name)}.${quoted(child)}")
                .as(s"${nested.name}.$child".replace(".", "_"))
            }
            flattenDataframe(df.select(keptColumns ++ childColumns: _*))
          case _ =>
            // Only an ArrayType can reach this branch: `firstNested` accepts arrays and structs.
            val keptColumns = otherNames.map(name => col(quoted(name)))
            val exploded = explode_outer(col(quoted(nested.name))).as(nested.name)
            flattenDataframe(df.select(keptColumns :+ exploded: _*))
        }
    }
  }
  

In [0]:
%scala
val flattendedJSON = flattenDataframe(jsonDF)

// Schema of the JSON after flattening.
flattendedJSON.schema

// Rows of the DataFrame after flattening, shown with truncate = false.
flattendedJSON.show(truncate = false)

##### Option 2. Non-recursive. Donuts data

In [0]:
%scala
// List the files stored under /FileStore/donuts and display the listing.
display(dbutils.fs.ls("/FileStore/donuts"))

path,name,size,modificationTime
dbfs:/FileStore/donuts/donut1.json,donut1.json,1092,1706208948000


In [0]:
%python
# Read the donut JSON file with the "multiline" reader option enabled, then display it.
df2 = (
    spark.read
    .option("multiline", "true")
    .json("dbfs:/FileStore/donuts/donut1.json")
)
display(df2)

batters,id,name,ppu,topping,type
"List(List(List(1001, Regular), List(1002, Chocolate), List(1003, Blueberry), List(1004, Devil's Food)))",1,Cake,0.55,"List(List(5001, None), List(5002, Glazed), List(5005, Sugar), List(5007, Powdered Sugar), List(5006, Chocolate with Sprinkles), List(5003, Chocolate), List(5004, Maple))",donut


In [0]:
%python
from pyspark.sql.types import *
from pyspark.sql.functions import *

def flatten(df):
    """Flatten the struct and array columns of ``df`` into plain top-level columns.

    Repeatedly takes the first remaining struct or array column and rewrites it
    until the schema contains no such column:

    * a struct column is replaced by one column per child, named
      ``<parent>_<child>`` and appended after the existing columns;
    * an array column is replaced in place by ``explode_outer`` of itself
      (one row per element).

    A progress line is printed for every column processed.

    :param df: DataFrame whose schema may contain struct and array columns.
    :return: DataFrame without top-level struct or array columns.
    """
    def quoted(name):
        # Back-tick quote a name so dots and other special characters in it are
        # taken literally when the string is parsed as a column reference.
        return "`" + name.replace("`", "``") + "`"

    def complex_fields_of(frame):
        # Map column name -> data type for the columns that still need flattening.
        return {
            field.name: field.dataType
            for field in frame.schema.fields
            if isinstance(field.dataType, (ArrayType, StructType))
        }

    complex_fields = complex_fields_of(df)
    while complex_fields:
        # Always work on the first remaining complex column, in schema order.
        col_name, col_type = next(iter(complex_fields.items()))
        print ("Processing :"+col_name+" Type : "+str(type(col_type)))

        if isinstance(col_type, StructType):
            # Flatten the struct: promote every child to a top-level column.
            expanded = [
                col(quoted(col_name) + "." + quoted(child.name)).alias(col_name + "_" + child.name)
                for child in col_type
            ]
            df = df.select("*", *expanded).drop(col_name)
        elif isinstance(col_type, ArrayType):
            # Explode the array: its elements become rows.
            df = df.withColumn(col_name, explode_outer(col(quoted(col_name))))

        # The rewrite may have exposed new nested columns, so rescan the schema.
        complex_fields = complex_fields_of(df)
    return df

In [0]:
%python
# Flatten the donut DataFrame with flatten() and display the result.
df2_flatten = flatten(df2)
display(df2_flatten)

Processing :batters Type : <class 'pyspark.sql.types.StructType'>
Processing :topping Type : <class 'pyspark.sql.types.ArrayType'>
Processing :topping Type : <class 'pyspark.sql.types.StructType'>
Processing :batters_batter Type : <class 'pyspark.sql.types.ArrayType'>
Processing :batters_batter Type : <class 'pyspark.sql.types.StructType'>


id,name,ppu,type,topping_id,topping_type,batters_batter_id,batters_batter_type
1,Cake,0.55,donut,5001,,1001,Regular
1,Cake,0.55,donut,5001,,1002,Chocolate
1,Cake,0.55,donut,5001,,1003,Blueberry
1,Cake,0.55,donut,5001,,1004,Devil's Food
1,Cake,0.55,donut,5002,Glazed,1001,Regular
1,Cake,0.55,donut,5002,Glazed,1002,Chocolate
1,Cake,0.55,donut,5002,Glazed,1003,Blueberry
1,Cake,0.55,donut,5002,Glazed,1004,Devil's Food
1,Cake,0.55,donut,5005,Sugar,1001,Regular
1,Cake,0.55,donut,5005,Sugar,1002,Chocolate


##### Option 3. Recursive with `json4s`

In [0]:
%scala
import org.json4s._
import org.json4s.native.JsonMethods._

/**
 * Flattens a JSON document into a map from dotted path to leaf value.
 *
 * Object keys and array indices along the way to a leaf are joined with ".", e.g.
 * `{"a": {"b": 1}, "d": [3, 4]}` yields the keys `a.b`, `d.0` and `d.1`.
 */
object JsonFlattener {
  /**
   * Parses `jsonString` and flattens the resulting document.
   *
   * @param jsonString JSON text to parse
   * @return map from dotted path to leaf value
   */
  def flatten(jsonString: String): Map[String, JValue] =
    flatten(parse(jsonString))

  /**
   * Flattens an already parsed JSON value.
   *
   * `JNull` and `JNothing` contribute no entries, and neither do empty objects or arrays.
   * A value that is itself a leaf is returned under the empty key `""`.
   *
   * @param json value to flatten
   * @return map from dotted path to leaf value
   */
  def flatten(json: JValue): Map[String, JValue] = json match {
    case JObject(fields) =>
      fields.flatMap { case (key, value) => prefixed(key, flatten(value)) }.toMap

    case JArray(elements) =>
      elements.zipWithIndex.flatMap { case (value, index) =>
        prefixed(index.toString, flatten(value))
      }.toMap

    case JNothing | JNull =>
      Map.empty

    case leaf =>
      Map("" -> leaf)
  }

  // Prepends `prefix` to every path in `nested`. A leaf arrives under the empty key, so it takes
  // the prefix as its whole path; joining it with "." would leave a trailing dot ("a.b.").
  private def prefixed(prefix: String, nested: Map[String, JValue]): Map[String, JValue] =
    nested.map { case (path, value) =>
      (if (path.isEmpty) prefix else s"$prefix.$path") -> value
    }
}

In [0]:
%scala
// Flatten a small sample document and print every (path, value) entry.
val jsonString = """{"a": {"b": 1, "c": 2}, "d": [3, 4]}"""
// `flatten` is a member of the JsonFlattener object, so the call has to be qualified.
val flattenedMap = JsonFlattener.flatten(jsonString)
flattenedMap.foreach(println)