In [None]:
from pyspark.sql.types import *
import requests
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
import json

# Module-level Spark session used by EdFiSchemaGenerator. The Delta Lake
# package is requested so DataFrames can be written with format('delta').
# The original chain called appName() twice; only the last value
# ('MDESparkApp') took effect, so the overridden "TestSparkApp" call is dropped.
spark = (
    SparkSession.builder
    .appName('MDESparkApp')
    .config("spark.jars.packages", "io.delta:delta-core_2.12:0.7.0")
    .getOrCreate()
)

class EdFiSchemaGenerator:
    """
    A class to Read Metadata file and Generate Ed-Fi Schema in Delta format.

    Attributes:
    -----------
        metadata_path: Path to the Metadata file.
        out_path: Path to store Output Delta tables in Ed-Fi schema.
        metadata_values: Rows collected from the metadata CSV.
        tables: Distinct table names found in the metadata (first-seen order).
        yet_to_process_tables: Tables whose schema has not been built yet.
        swagger_json: Parsed Ed-Fi Open API Swagger document.
        processed_tables: Tables whose schema has been built.
        dependency_dict: Maps a table name to the tables it references.
        schemas: Maps a table name to its Spark StructType.
        primitive_tables: Tables that contain no array / $ref columns.
    """

    # String spellings accepted as "true" for flag columns. The metadata CSV
    # is read without a schema, so flag values arrive as strings.
    _TRUE_STRINGS = frozenset({'true', 't', '1', 'y', 'yes'})

    def __init__(self, metadata_path, out_path, swagger_url, request_timeout=60):
        """
        Load the metadata CSV and the Swagger document.

        Parameters:
        -----------
            metadata_path: Path to the Metadata CSV file (read with a header row).
            out_path: Path to store Output Delta tables in Ed-Fi schema.
            swagger_url: URL to the Ed-Fi Open API Swagger endpoint.
            request_timeout: Seconds to wait for the Swagger endpoint before
                giving up (new optional argument; previously there was no timeout).

        Raises:
        -------
            requests.HTTPError: If the Swagger endpoint answers with an error status.
        """
        self.metadata_path = metadata_path
        self.out_path = out_path
        self.metadata_values = spark.read.csv(self.metadata_path, header=True).collect()

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # processing order is deterministic (a set would not guarantee that).
        self.tables = list(dict.fromkeys(row.table_name for row in self.metadata_values))
        self.yet_to_process_tables = list(self.tables)

        # Fail fast on a hung or failing endpoint instead of blocking forever
        # or trying to parse an error page as JSON.
        response = requests.get(swagger_url, timeout=request_timeout)
        response.raise_for_status()
        self.swagger_json = json.loads(response.text)

        self.processed_tables = []
        self.dependency_dict = {}
        self.schemas = {}
        self.primitive_tables = []

    @classmethod
    def _is_identity(cls, row):
        """Return True when the row's 'x-Ed-Fi-isIdentity' flag is set."""
        value = row['x-Ed-Fi-isIdentity']
        if value is None:
            return False
        if isinstance(value, str):
            # A plain truthiness test would treat the string "False" as True.
            return value.strip().lower() in cls._TRUE_STRINGS
        return bool(value)

    def _group_rows_by_table(self):
        """Return a dict mapping each table name to its metadata rows (in file order)."""
        grouped = {}
        for row in self.metadata_values:
            grouped.setdefault(row.table_name, []).append(row)
        return grouped

    def get_reference(self, row):
        """
        Extract the referenced table name (and optional flatten list) from a row.

        Returns:
        --------
            (reference, flattened_fields) where reference is the last '/'
            segment of the row's reference and flattened_fields is the list of
            field names given after '::' (empty when there is none).
            (None, None) when the row is neither an array nor a $ref.
        """
        flattened_fields = []
        if row.type == 'array':
            # The last path segment of `items` carries one trailing character
            # that is not part of the name; it is stripped here.
            reference = row.items.split('/')[-1][:-1]
        elif row['$ref'] is not None:
            ref_value = row['$ref']
            if '::' in ref_value:
                # Format: <ref>::<field>:<field>:...
                parts = ref_value.split('::')
                reference = parts[0]
                flattened_fields = parts[1].split(':')
            else:
                reference = ref_value
        else:
            return None, None
        return reference.split('/')[-1], flattened_fields

    def get_data_type(self, dtype, format):
        """
        Map a metadata (type, format) pair to a Spark data type instance.

        Returns None when dtype is not one of 'string', 'integer', 'number'
        or 'boolean'.
        """
        if dtype == 'string':
            if format == 'date':
                return DateType()
            if format == 'date-time':
                return TimestampType()
            return StringType()
        if dtype == 'integer':
            return IntegerType()
        if dtype == 'number':
            # NOTE(review): DecimalType() uses pyspark's default precision and
            # scale; confirm that is wide enough for the source 'number' fields.
            return DecimalType()
        if dtype == 'boolean':
            return BooleanType()
        return None

    def create_primitive_schemas(self):
        """
        Build schemas for tables that contain no array and no $ref columns.

        Each such table is recorded in schemas, primitive_tables and
        processed_tables and removed from yet_to_process_tables. Identity
        columns are created as non-nullable.
        """
        print("Creating primitive schemas.")
        rows_by_table = self._group_rows_by_table()
        for table in self.tables:
            table_rows = rows_by_table.get(table, [])
            has_reference = any(
                row['$ref'] is not None or row.type == 'array' for row in table_rows
            )
            if has_reference:
                continue
            self.schemas[table] = StructType([
                StructField(
                    row.column_name,
                    self.get_data_type(row.type, row.format),
                    not self._is_identity(row),
                )
                for row in table_rows
            ])
            self.primitive_tables.append(table)
            self.processed_tables.append(table)
            self.yet_to_process_tables.remove(table)
        print("Completed creating primitive schemas.")

    def create_dependency_dict(self):
        """
        Fill dependency_dict with, per table, the distinct tables it references.

        Only rows that are arrays or carry a $ref contribute an entry.
        """
        for row in self.metadata_values:
            if row.type != 'array' and row['$ref'] is None:
                continue
            reference, _ = self.get_reference(row)
            # setdefault keeps the existing list: a repeated reference must not
            # reset the dependencies already recorded for the table.
            dependencies = self.dependency_dict.setdefault(row.table_name, [])
            if reference not in dependencies:
                dependencies.append(reference)

    def create_schemas_for_complex_tables(self):
        """
        Build schemas for the remaining tables once their dependencies exist.

        Tables are resolved in passes; a table is built as soon as every table
        it references has been processed.

        Raises:
        -------
            ValueError: If a full pass resolves nothing, i.e. the remaining
                tables depend on each other, on themselves, or on tables that
                are never defined.
        """
        print("Creating complex schemas.")
        rows_by_table = self._group_rows_by_table()
        done = set(self.processed_tables)
        while self.yet_to_process_tables:
            progressed = False
            # Iterate over a snapshot: the list is mutated inside the loop.
            for entity in list(self.yet_to_process_tables):
                dependencies = self.dependency_dict.get(entity, [])
                if any(dep not in done for dep in dependencies):
                    continue
                spark_schema = []
                for row in rows_by_table.get(entity, []):
                    reference, flatten_fields = self.get_reference(row)
                    if row.type == 'array':
                        # Handle Array Objects
                        column_schema = StructField(
                            row.column_name, ArrayType(self.schemas[reference]), True
                        )
                    elif row['$ref'] is not None:
                        # Handle Normal Objects
                        if flatten_fields:
                            # Inline only the selected fields of the referenced
                            # schema instead of nesting the whole struct.
                            spark_schema += [
                                field for field in self.schemas[reference].fields
                                if field.name in flatten_fields
                            ]
                            continue
                        column_schema = StructField(
                            row.column_name, self.schemas[reference], True
                        )
                    else:
                        # Primitive Data type
                        # NOTE(review): unlike create_primitive_schemas, the
                        # identity flag is not applied here, so these columns
                        # are always nullable - confirm this is intended.
                        column_schema = StructField(
                            row.column_name, self.get_data_type(row.type, row.format)
                        )
                    spark_schema.append(column_schema)
                self.schemas[entity] = StructType(spark_schema)
                self.processed_tables.append(entity)
                done.add(entity)
                self.yet_to_process_tables.remove(entity)
                progressed = True
                print(f'created schema for {entity}')
            if not progressed:
                # Without this guard the while loop would spin forever.
                unresolved = {
                    table: [
                        dep for dep in self.dependency_dict.get(table, [])
                        if dep not in done
                    ]
                    for table in self.yet_to_process_tables
                }
                raise ValueError(
                    "Cannot create schemas; unmet or circular dependencies: "
                    f"{unresolved}"
                )
        print("Completed creating complex schemas.")

    def write_empty_delta_files(self):
        """
        Write one empty Delta table per two-segment Swagger path.

        The schema is looked up by the part of the GET 200 response's item
        reference after the last '_'. A nullable 'LastModifiedDate' timestamp
        column is appended, and the table is saved under out_path using the
        last path segment as its folder name.
        """
        for path, operations in self.swagger_json['paths'].items():
            # Only paths with exactly two '/' characters are handled.
            if path.count('/') != 2:
                continue
            response = operations['get']['responses']['200']['schema']
            reference = response['items']['$ref'].split('_')[-1]
            empty_df = spark.createDataFrame(
                spark.sparkContext.emptyRDD(), self.schemas[reference]
            ).withColumn('LastModifiedDate', f.lit(None).cast(TimestampType()))
            empty_df.write.format('delta').mode('overwrite').save(
                f"{self.out_path}/{path.split('/')[-1]}"
            )

    def create_schemas(self):
        """Run the full pipeline: primitive schemas, dependencies, complex schemas, Delta output."""
        self.create_primitive_schemas()
        self.create_dependency_dict()
        self.create_schemas_for_complex_tables()
        self.write_empty_delta_files()
