In [None]:
from array import *
import time
from typing import Dict

from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from functools import reduce



In [None]:
def rename_cols(df: DataFrame, col_names: Dict[str, str]) -> DataFrame:
    """Return ``df`` with its top-level columns renamed according to ``col_names``.

    Every column of ``df`` is kept, in its original order. A column whose name is a key of
    ``col_names`` is aliased to the mapped value; all other columns keep their name.

    Args:
        df: Input DataFrame.
        col_names: Mapping ``old name -> new name``. Keys that are not columns of ``df`` are
            ignored.

    Returns:
        A new DataFrame with the renamed columns.
    """
    # Backtick-quote each name (doubling any embedded backtick) so a name containing dots,
    # e.g. a flattened JSON key "a.b", is resolved as one top-level column instead of being
    # parsed as struct-field access.
    return df.select(*[
        col("`" + name.replace("`", "``") + "`").alias(col_names.get(name, name))
        for name in df.columns
    ])

def update_column_names(df: DataFrame, index: int) -> DataFrame:
    """Return ``df`` with ``*{index}`` appended to every top-level column name.

    Only top-level columns are renamed; fields nested inside struct columns keep their names.
    NOTE(review): presumably the suffix exists so that top-level names cannot clash with struct
    field names once structs are expanded later on -- confirm.

    Args:
        df: Input DataFrame.
        index: Value appended after ``*``; e.g. ``index=1`` turns ``id`` into ``id*1``.

    Returns:
        A new DataFrame with the suffixed column names.
    """
    suffixed = {name: f"{name}*{index}" for name in df.columns}
    # Return the renamed frame (not the input) so the suffix is actually applied.
    return rename_cols(df, suffixed)

def final_column_names(df: DataFrame, stringToDelete) -> DataFrame:
    """Remove ``stringToDelete`` from every top-level column name and lower-case the result.

    Args:
        df: Input DataFrame.
        stringToDelete: Substring removed (every occurrence) from each column name.
            NOTE(review): expected to be a ``str`` -- confirm with callers.

    Returns:
        A DataFrame with the cleaned column names, in the original column order. Two names
        that differ only by case or by ``stringToDelete`` end up identical.
    """
    cleaned = {}
    for name in df.columns:
        cleaned[name] = name.replace(stringToDelete, '').lower()
    return rename_cols(df, cleaned)

In [None]:
# This function flattens the JSON, but it CANNOT be applied to the full JSON because some attributes share the same name
def flatten_json(df_arg: DataFrame, index: int = 1) -> DataFrame:
    """Flatten all nested struct and array columns of a DataFrame into top-level columns.

    When ``index == 1`` (the default) the input is first passed through
    ``update_column_names(df_arg, 1)``, which is meant to append ``*1`` to every top-level
    column name. Then, repeatedly, the first column in schema order whose type is
    ``ArrayType`` or ``StructType`` is expanded:

    * ``ArrayType``: the column is replaced in place by ``explode_outer`` of itself, giving
      one row per element. Rows whose array is null or empty are kept, with a null value.
      Exploding several array columns multiplies the row count.
    * ``StructType`` column ``s``: every field ``f`` becomes a new top-level column named
      ``{s}_{f}``, appended at the end, and ``s`` itself is dropped.

    This repeats until no struct or array column remains. Map and primitive columns are left
    unchanged.

    Caveat: if an expanded field name equals the name of an existing top-level column, the
    subsequent rename is an ambiguous reference and Spark raises an error.

    Args:
        df_arg: DataFrame to flatten.
        index: ``1`` applies the initial ``update_column_names`` step; any other value skips it.

    Returns:
        A DataFrame with no struct or array columns.
    """
    # Explicit imports: `from array import *` at the top of the notebook also exports an
    # `ArrayType`, so relying on star-import order for these names is fragile.
    from pyspark.sql.functions import col, explode_outer
    from pyspark.sql.types import ArrayType, StructType

    def _quoted(name: str) -> str:
        # Backtick-quote a column name (doubling embedded backticks) so that dots in it are
        # not interpreted as struct-field access.
        return "`" + name.replace("`", "``") + "`"

    df = update_column_names(df_arg, index) if index == 1 else df_arg

    # Loop instead of recursing: one expansion per pass, so schemas with many nested
    # columns cannot hit Python's recursion limit.
    while True:
        nested = next(
            (field for field in df.schema.fields
             if isinstance(field.dataType, (ArrayType, StructType))),
            None,
        )
        if nested is None:
            return df

        column_name = nested.name

        if isinstance(nested.dataType, ArrayType):
            df = df.withColumn(column_name, explode_outer(col(_quoted(column_name))))
            continue

        field_names = nested.dataType.fieldNames()

        # If a field shares the struct column's name, select("*", "<col>.*") would produce two
        # columns with that name and drop() would remove both, so rename the struct first.
        # Compared case-insensitively because Spark's default resolver is case-insensitive.
        struct_col = column_name
        if column_name.lower() in {name.lower() for name in field_names}:
            struct_col = column_name + "#1"
            df = df.withColumnRenamed(column_name, struct_col)

        expanded = df.select("*", f"{_quoted(struct_col)}.*").drop(struct_col)
        # Prefix each new column with the ORIGINAL struct name, not the temporary "#1" name.
        df = rename_cols(expanded, {name: f"{column_name}_{name}" for name in field_names})