In [None]:
import logging

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode_outer
from pyspark.sql.types import ArrayType, StructType

In [None]:
# Build (or reuse, via getOrCreate) a SparkSession running in local mode
# with a single worker thread ("local[1]").
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName("spark")
    .getOrCreate()
)

In [None]:
# Load the department budgets. The "multiline" option lets Spark parse a JSON
# document whose records span several lines, rather than expecting one JSON
# object per line.
df_budgets = (
    spark.read
    .option("multiline", "true")
    .json('datasets/json/department_budgets.json')
)

In [None]:
# Print the schema Spark inferred from the JSON file, as a tree
# (nested structs/arrays appear as indented children).
df_budgets.printSchema()

In [None]:
# The nested data can still be queried directly, without flattening it first.
# Filter on department_id BEFORE projecting 'offices', so the filter column is
# still part of the DataFrame. Selecting first would rely on Spark re-resolving
# a column the projection has already dropped.
df_budgets.where('department_id == 1').select('offices').show(truncate=False)

## Flattening JSON into a columnar format is usually easier, cleaner and more scalable than querying nested fields directly.
- Suggestion: always test and benchmark both approaches (JSON-path access vs. flattening) to compare their performance.

In [None]:
# Spark helpers to flatten nested StructType / ArrayType columns.
# Adapted from GitHub: https://bit.ly/43ZegOL


def _quote_identifier(name):
    """Return *name* as a backtick-quoted Spark column identifier.

    Quoting stops Spark from reading dots (or other special characters) in a
    field name as nested-field access. Embedded backticks are doubled.
    """
    return '`' + name.replace('`', '``') + '`'


def _complex_fields(df):
    """Return {column name: data type} for df's top-level array/struct columns.

    The dict preserves the column order of ``df.schema``.
    """
    return {
        field.name: field.dataType
        for field in df.schema.fields
        if isinstance(field.dataType, (ArrayType, StructType))
    }


def flatten_dataframe(df):
    """Flatten all nested StructType and ArrayType columns of a DataFrame.

    Repeats until no top-level struct or array column remains:

    * a StructType column ``c`` is replaced by one column per sub-field ``f``,
      named ``c_f``;
    * an ArrayType column is exploded with ``explode_outer``, giving one row
      per element (a null or empty array yields a single row holding null).

    Deeply nested types are handled because each step may expose new
    struct/array columns, which are picked up on the next iteration.

    NOTE(review): exploding several independent arrays multiplies the row
    count (one row per combination of elements). Promoting a struct field to
    ``c_f`` may also collide with an existing top-level column of that name;
    no guard exists for either case.

    Args:
        df: Spark DataFrame, possibly containing semi-structured columns.

    Returns:
        A new Spark DataFrame with no StructType or ArrayType columns. Other
        complex types (e.g. MapType) are left untouched.

    Raises:
        Exception: any error raised while flattening is logged and re-raised,
        so callers never receive a silent ``None``.
    """
    try:
        complex_fields = _complex_fields(df)

        while complex_fields:
            # Process the first remaining complex column in schema order.
            col_name, data_type = next(iter(complex_fields.items()))
            quoted = _quote_identifier(col_name)

            if isinstance(data_type, StructType):
                # Flatten the struct: promote each sub-field to a top-level
                # column. getField takes the literal field name, so dots
                # inside it are not misread as nesting.
                expanded = [
                    col(quoted).getField(field.name).alias(f'{col_name}_{field.name}')
                    for field in data_type.fields
                ]
                df = df.select('*', *expanded).drop(col_name)
            else:
                # Only ArrayType remains here: emit one row per element while
                # keeping rows whose array is null or empty.
                df = df.withColumn(col_name, explode_outer(col(quoted)))

            # Recompute: the step above may have exposed new complex columns.
            complex_fields = _complex_fields(df)

        return df

    except Exception as e:
        # Log with traceback, then propagate instead of returning None.
        logging.exception('Error while flattening JSON data: {}'.format(e))
        raise

In [None]:
# Flatten the nested JSON DataFrame: struct fields become top-level columns
# and array elements become separate rows.
df_budgets_flat = flatten_dataframe(df_budgets) 

In [None]:
# Preview the flattened data: first 20 rows, with long cell values truncated
# (these are Spark's defaults, spelled out explicitly).
df_budgets_flat.show(n=20, truncate=True)

In [None]:
# Print the schema of the flattened DataFrame. Former struct fields now show up
# as top-level columns named <parent>_<child>.
df_budgets_flat.printSchema()

In [None]:
# Register the flattened DataFrame as the temporary view 'budgets_flat' so it
# can be queried with spark.sql. Any existing view with that name is replaced.
df_budgets_flat.createOrReplaceTempView('budgets_flat')

In [None]:
# Query the flattened view with plain SQL and display up to 50 rows.
# The statement must be balanced: a stray ')' after the table name makes
# Spark's SQL parser raise a ParseException.
spark.sql('''
          select * from budgets_flat
          ''').show(n=50)