In [None]:
from curses import meta
from pyspark.sql.types import *
import requests
import pyspark.sql.functions as f
import json
import re

class OpenAPIUtil:
    """
    A utility class to help process transformations using OpenAPI (Swagger).

    Parameters:
        1) swagger_url: URL of the OpenAPI (Swagger) JSON endpoint.

    Methods:
        1) create_spark_schemas(): returns a dictionary of Spark schemas of all
           definitions, keyed by the pluralized entity name.
        2) create_metadata(): returns a dictionary mapping each pluralized
           entity name to a list of [column name, data type name,
           pseudonymization operation] entries.
        3) write_oea_metadata(destination_path): writes the OEA metadata out in
           CSV format at the destination path.
    """

    def __init__(self, swagger_url):
        """Download the swagger document and initialise empty working state."""
        response = requests.get(swagger_url)
        # Fail on HTTP errors here instead of trying to parse an error page.
        response.raise_for_status()
        self.swagger_json = json.loads(response.text)
        # Keys guaranteed to exist (possibly as None) on every field entry
        # built by create_definitions().
        self.metadata_headers = ['table_name', 'column_name', 'type', 'format', 'maxLength', 'required', 'items', '$ref', 'pseudonymization']
        self.definitions = {}       # table name -> {column name -> field info dict}
        self.metadata = {}          # pluralized table name -> list of column metadata rows
        self.tables = []            # table names, in definition order
        self.schemas = {}           # pluralized table name -> Spark StructType
        self.dependency_dict = {}   # table name -> list of referenced table names
        self.dependency_order = []  # table names, referenced tables first
        self.visited = {}           # table name -> bool, used by dfs()

    def get_reference(self, row):
        """
        Return the table name referenced by a field, or None when the field
        does not reference another definition.

        For arrays the reference is read from ``items['$ref']``; otherwise from
        the field's own ``$ref``. The table name is the part of the reference
        after the last '/' and the last '_'.
        """
        if row.get('type') == 'array':
            # Arrays of primitives carry no '$ref' inside 'items'.
            reference = (row.get('items') or {}).get('$ref')
        else:
            reference = row.get('$ref')
        if reference is None:
            return None
        return reference.split('/')[-1].split('_')[-1]

    def pluralize(self, noun):
        """
        Return the plural form of ``noun`` used as the key for schemas and
        metadata.

        Rules: 'person' -> 'people'; a trailing s/x/z gets 'es'; a trailing 'y'
        preceded by a vowel gets 's' (e.g. 'survey' -> 'surveys'); any other
        trailing 'y' becomes 'ies'; everything else gets 's'.
        """
        if noun == 'person':
            return 'people'
        if noun.endswith(('s', 'x', 'z')):
            return noun + 'es'
        if noun.endswith('y'):
            if len(noun) > 1 and noun[-2].lower() in 'aeiou':
                return noun + 's'
            return noun[:-1] + 'ies'
        return noun + 's'

    def get_data_type(self, dtype, format):
        """
        Map a swagger primitive ``type``/``format`` pair to a Spark data type.

        Returns None when ``dtype`` is not one of 'string', 'integer',
        'number' or 'boolean'.
        """
        if dtype == 'string':
            if format == 'date':
                return DateType()
            if format == 'date-time':
                return TimestampType()
            return StringType()
        if dtype == 'integer':
            return IntegerType()
        if dtype == 'number':
            # NOTE(review): DecimalType() is built without arguments, so Spark's
            # default precision/scale apply - confirm that is intended for
            # fractional 'number' fields.
            return DecimalType()
        if dtype == 'boolean':
            return BooleanType()
        return None

    def create_definitions(self):
        """
        Populate ``self.definitions`` and ``self.tables`` from the
        'definitions' section of the swagger document.

        Each field entry is a copy of the swagger property without its
        'description', extended with 'required', 'table_name', 'column_name'
        and 'pseudonymization' (renamed from 'x-Ed-Fi-pseudonymization'), and
        with every key of ``self.metadata_headers`` present (None if absent).
        The swagger document itself is left unmodified.
        """
        for entity, entity_spec in self.swagger_json['definitions'].items():
            # A definition may legitimately declare no properties.
            properties = entity_spec.get('properties', {})
            required_fields = entity_spec.get('required', [])
            table_name = entity.split('_')[-1]
            table_schema = {}

            for prop, prop_spec in properties.items():
                # Work on a copy so the downloaded swagger document is not mutated.
                field_info = {k: v for k, v in prop_spec.items() if k != 'description'}
                field_info['required'] = prop in required_fields
                field_info['table_name'] = table_name
                field_info['column_name'] = prop
                if 'x-Ed-Fi-pseudonymization' in field_info:
                    field_info['pseudonymization'] = field_info.pop('x-Ed-Fi-pseudonymization')
                for header in self.metadata_headers:
                    field_info.setdefault(header, None)
                table_schema[prop] = field_info

            self.definitions[table_name] = table_schema
        self.tables = list(self.definitions)

    def create_metadata(self):
        """
        Build and return ``self.metadata``: for every table (keyed by its
        pluralized name) a list of [column name, data type name,
        pseudonymization operation] rows.

        Columns referencing another table contribute that table's rows
        instead: only the plucked ones when 'x-Ed-Fi-fields-to-pluck' is set to
        something other than ["*"], or all of them when 'x-Ed-Fi-explode' is
        truthy. Spark schemas are created first if they do not exist yet.
        """
        if not self.schemas:
            self.create_spark_schemas()
        # dependency_order lists referenced tables first, so their metadata is
        # already available when a referencing table is processed.
        for table_name in self.dependency_order:
            key = self.pluralize(table_name)
            table_metadata = []
            for col_name, col_schema in self.definitions[table_name].items():
                referenced_table = self.get_reference(col_schema)
                fields_to_pluck = col_schema.get('x-Ed-Fi-fields-to-pluck')

                if referenced_table is not None and fields_to_pluck is not None and fields_to_pluck != ["*"]:
                    table_metadata += [x for x in self.metadata[self.pluralize(referenced_table)] if x[0] in fields_to_pluck]

                elif referenced_table is not None and col_schema.get('x-Ed-Fi-explode'):
                    table_metadata += self.metadata[self.pluralize(referenced_table)]

                else:
                    op = col_schema['pseudonymization']
                    if op is None:
                        op = 'no-op'
                    table_metadata.append([col_name, self.schemas[key][col_name].dataType.typeName(), op])
            self.metadata[key] = table_metadata
        return self.metadata

    def write_oea_metadata(self, destination_path):
        """
        Write the OEA metadata in CSV format to ``destination_path``.

        Each table produces one row holding only the entity name, followed by
        one row per column. Metadata is created first if it does not exist yet.
        Relies on a ``spark`` session name being available in the enclosing
        scope (it is not defined by this class).
        """
        # self.metadata is a dict, so test for emptiness rather than '== []'.
        if not self.metadata:
            self.create_metadata()
        oea_metadata = []
        for table_name, table_metadata in self.metadata.items():
            oea_metadata.append([table_name, None, None, None])
            for col_metadata in table_metadata:
                oea_metadata.append([None] + col_metadata)
        metadata_df = spark.createDataFrame(oea_metadata, ['Entity Name','Attribute Name','Attribute Data Type','Pseudonymization'])
        metadata_df.write.format('csv').save(destination_path)

    def create_dependency_dict(self):
        """
        Populate ``self.dependency_dict`` with, for every table that references
        other tables, the list of distinct referenced table names.
        """
        for table_name, table_schema in self.definitions.items():
            for column_info in table_schema.values():
                referenced_table = self.get_reference(column_info)
                if referenced_table is None:
                    continue
                dependencies = self.dependency_dict.setdefault(table_name, [])
                if referenced_table not in dependencies:
                    dependencies.append(referenced_table)

    def dfs(self, table_name):
        """
        Depth-first visit of ``table_name``: append every not-yet-visited table
        it references (recursively) to ``self.dependency_order``, then the
        table itself.
        """
        # Mark before recursing so a table is never entered twice.
        self.visited[table_name] = True
        for dependent_table in self.dependency_dict.get(table_name, []):
            if not self.visited[dependent_table]:
                self.dfs(dependent_table)
        self.dependency_order.append(table_name)

    def create_dependency_order(self):
        """
        Compute ``self.dependency_order`` so that every table appears after the
        tables it references. Safe to call more than once.
        """
        # Start from a clean slate so repeated calls do not append duplicates.
        self.dependency_order = []
        for table_name in self.tables:
            self.visited[table_name] = False
        for table_name in self.tables:
            if not self.visited[table_name]:
                self.dfs(table_name)

    def create_spark_schemas_from_definitions(self):
        """
        Build a Spark StructType per table, in dependency order, and store it in
        ``self.schemas`` under the pluralized table name.

        Array columns become ArrayType of the referenced table's schema (or of
        the primitive item type when the items carry no reference); '$ref'
        columns reuse the referenced table's schema; other columns are mapped
        with get_data_type(). 'pseudonymization', 'x-Ed-Fi-explode' and
        'x-Ed-Fi-fields-to-pluck' are copied into the StructField metadata.
        """
        for entity in self.dependency_order:
            spark_fields = []
            for col_name, col_info in self.definitions[entity].items():
                col_metadata = {}
                if 'pseudonymization' in col_info:
                    col_metadata['pseudonymization'] = col_info['pseudonymization']
                referenced_table = self.get_reference(col_info)
                if col_info['type'] == 'array':
                    if referenced_table is not None:
                        element_type = self.schemas[self.pluralize(referenced_table)]
                    else:
                        # Array of primitives: derive the element type from 'items'.
                        items = col_info.get('items') or {}
                        element_type = self.get_data_type(items.get('type'), items.get('format'))
                    datatype = ArrayType(element_type)
                    if 'x-Ed-Fi-explode' in col_info:
                        col_metadata['x-Ed-Fi-explode'] = col_info['x-Ed-Fi-explode']
                elif referenced_table is not None:
                    datatype = self.schemas[self.pluralize(referenced_table)]
                    if 'x-Ed-Fi-fields-to-pluck' in col_info:
                        col_metadata['x-Ed-Fi-fields-to-pluck'] = col_info['x-Ed-Fi-fields-to-pluck']
                else:
                    datatype = self.get_data_type(col_info['type'], col_info['format'])
                # Required swagger fields map to non-nullable Spark fields.
                spark_fields.append(StructField(col_name, datatype, not col_info['required'], col_metadata))
            self.schemas[self.pluralize(entity)] = StructType(spark_fields)

    def create_spark_schemas(self):
        """
        Return the dictionary of Spark schemas keyed by pluralized entity name,
        building definitions, dependencies and schemas on first use.
        """
        if not self.schemas:
            self.create_definitions()
            self.create_dependency_dict()
            self.create_dependency_order()
            self.create_spark_schemas_from_definitions()
        return self.schemas
