In [None]:
import requests
import json
import uuid
from requests.auth import HTTPBasicAuth
from datetime import datetime
import logging
import csv
import pandas as pd
from io import StringIO

# Logger named 'EdFiClient'; it is used by both classes defined in this notebook
# (EdFiOEAChild and EdFiClient). EdFiClient.__init__ sets the root logger's level and handler format.
logger = logging.getLogger('EdFiClient')

In [None]:
#General parameters
# Workspace name passed to oea.set_workspace() and to the EdFiClient constructor further down.
workspace = "dev"

#EdFi specific parameters
# NOTE(review): kvName, authUrl, dataManagementUrl, changeQueriesUrl, dependenciesUrl, apiVersion,
# batchLimit, moduleName, minChangeVer and maxChangeVer are consumed by the EdFiClient(...) call in the
# last cell, but only exist here as commented-out samples. They must be defined before that cell runs -
# presumably injected as notebook / pipeline parameters; confirm.
# kvName = "kv-oea-hisddev"
# authUrl = "https://api.edgraph.dev/edfi/v5.2/saas/5eb775fb-4eff-4889-9eae-3919b7a2d321/oauth/token"
# dataManagementUrl = "https://api.edgraph.dev/edfi/v5.2/saas/data/v3/5eb775fb-4eff-4889-9eae-3919b7a2d321/2011"
# changeQueriesUrl = "https://api.edgraph.dev/edfi/v5.2/saas/changequeries/v1/5eb775fb-4eff-4889-9eae-3919b7a2d321/2011" 
# dependenciesUrl = "https://api.edgraph.dev/edfi/v5.2/saas/metadata/data/v3/5eb775fb-4eff-4889-9eae-3919b7a2d321/2011/dependencies"
# apiVersion = "5.2"
# batchLimit = 100
# moduleName = "Ed-Fi"

# Change-version window to extract; EdFiClient.landEntities falls back to the versions reported by the API
# when these are None.
# minChangeVer = None
# maxChangeVer = None

# NOTE(review): schoolYear / districtId are accepted by EdFiClient.__init__ but are not passed by the
# instantiation cell at the end of this notebook - confirm whether that is intended.
# schoolYear = None
# districtId = None

# metadataUrl = "https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/main/modules/module_catalog/Ed-Fi/utils/Metadata.csv"

In [None]:
%run OEA_py

In [None]:
# Point the existing `oea` object at the configured workspace.
# NOTE(review): `oea` is not created in this notebook before this cell - presumably it is provided by
# the OEA_py notebook run above; confirm.
oea.set_workspace(workspace)

In [None]:
class EdFiOEAChild(OEA):
    """ 
    NOTE: This class inherits features from the base class OEA and therefore,
    should be created / executed after running the notebook OEA_py
    """
    def __init__(self, workspace='dev', logging_level=logging.INFO, storage_account=None, keyvault=None, timezone=None):
        """ Creates the Ed-Fi specific OEA instance.
            All arguments are forwarded unchanged, and in the same order, to the OEA base class constructor;
            this subclass adds no state of its own here.
        """
        # Call the base class constructor to initialize inherited attributes
        super().__init__(workspace, logging_level, storage_account, keyvault, timezone)

    def upsert(self, df, destination_path, primary_key='id', partitioning=False, partitioning_cols = []):
        """ Upserts the data in the given dataframe into the specified destination using the given primary_key to identify the updates.
            If there is no delta table found in the destination_path, one will be created.

            df: spark dataframe holding the new / changed rows (column names are fixed via fix_column_names first).
            destination_path: lake path of the delta table; converted with to_url.
            primary_key: column used (together with the partition columns, when partitioning) to de-duplicate and match rows.
            partitioning: when True the table is written / merged as a partitioned table.
            partitioning_cols: partition columns; an empty list means DistrictId + SchoolYear. Never modified here.

            Behaviour for an existing table:
              - not partitioned: merge on primary_key.
              - partitioned by DistrictId / SchoolYear (or by default): merge on DistrictId + SchoolYear + primary_key.
                Rows without a match are inserted by the merge, so a partition that does not exist yet needs no special handling.
                An empty dataframe is skipped instead of being written.
              - partitioned by any other columns: the affected partitions are over-written dynamically.
        """
        destination_url = self.to_url(destination_path)
        df = self.fix_column_names(df)
        # Work on a copy so neither the shared default list nor the caller's list can be touched from here.
        partition_cols = list(partitioning_cols)

        if partitioning:
            df = df.dropDuplicates([primary_key] + partition_cols)
        else:
            df = df.dropDuplicates([primary_key])

        if not DeltaTable.isDeltaTable(spark, destination_url):
            logger.debug('No existing delta table found. Creating delta table.')
            if not partitioning:
                logger.info('Writing unpartitioned delta lake')
                df.write.format('delta').save(destination_url)
            elif len(partition_cols) == 0:
                logger.info('Partitioning columns absent - defaulting to DistrictId and SchoolYear as partitioning columns')
                df.write.format('delta').partitionBy('DistrictId', 'SchoolYear').save(destination_url)
            else:
                partitioning_str = ', '.join(partition_cols)
                logger.info(f'Writing partitioned delta lake - partitioned by - {partitioning_str}')
                df.write.format('delta').partitionBy(*partition_cols).save(destination_url)
            return

        delta_table_sink = DeltaTable.forPath(spark, destination_url)

        if not partitioning:
            delta_table_sink.alias('sink').merge(df.alias('updates'), f'sink.{primary_key} = updates.{primary_key}').whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
            return

        #TODO: Generalize for arbitrary partitioning columns
        uses_default_partitions = (len(partition_cols) == 0) or (sorted(partition_cols) == ['DistrictId', 'SchoolYear'])
        if not uses_default_partitions:
            logger.info('Dynamically over-write the partition')
            spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
            df.write.format('delta').mode('overwrite').partitionBy(*partition_cols).save(destination_url)
            return

        # Assumption: Each DF should have constant DistrictId and SchoolYear per run
        # A single first() answers "is there anything to merge?"; the previous code launched four jobs
        # and tested DistrictId twice without ever looking at SchoolYear.
        if df.select('DistrictId', 'SchoolYear').first() is None:
            # Nothing to merge. Issuing an overwrite for an empty dataframe adds nothing and is risky
            # wherever dynamic partition overwrite is not honoured, so no write is attempted.
            logger.info('No rows to upsert - skipping write')
            return

        logger.info('Upsert by Partitions + PK Cols')
        merge_condition = f'sink.DistrictId = updates.DistrictId AND sink.SchoolYear = updates.SchoolYear AND sink.{primary_key} = updates.{primary_key}'
        delta_table_sink.alias('sink').merge(df.alias('updates'), merge_condition).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()

    def overwrite(self, df, destination_path, primary_key='id', partitioning = False, partitioning_cols = []):
        """ Replaces the content of the delta table at destination_path with the given (de-duplicated) dataframe.
            The delta table is created when none exists at destination_path yet.
            When partitioning is requested without partitioning_cols, DistrictId and SchoolYear are used.
        """
        target_url = self.to_url(destination_path)
        cleaned = self.fix_column_names(df)

        # Rows are unique per primary key, or per primary key + partition columns for partitioned tables.
        dedup_keys = [primary_key] + partitioning_cols if partitioning else [primary_key]
        cleaned = cleaned.dropDuplicates(dedup_keys)

        writer = cleaned.write.format('delta').mode('overwrite')
        if not partitioning:
            logger.info('Writing unpartitioned delta lake')
        elif len(partitioning_cols) == 0:
            logger.info('Partitioning columns absent - defaulting to DistrictId and SchoolYear as partitioning columns')
            writer = writer.partitionBy('DistrictId', 'SchoolYear')
        else:
            partitioning_str = ', '.join(partitioning_cols)
            logger.info(f'Writing partitioned delta lake - partitioned by - {partitioning_str}')
            writer = writer.partitionBy(*partitioning_cols)
        writer.save(target_url)
        
    def append(self, df, destination_path, primary_key='id', partitioning = False, partitioning_cols = []):
        """ Adds the rows of the given (de-duplicated) dataframe to the delta table at destination_path.
            The delta table is created when none exists at destination_path yet; in that case partitioning
            without partitioning_cols falls back to DistrictId and SchoolYear.
        """
        target_url = self.to_url(destination_path)
        cleaned = self.fix_column_names(df)

        # Rows are unique per primary key, or per primary key + partition columns for partitioned tables.
        dedup_keys = [primary_key] + partitioning_cols if partitioning else [primary_key]
        cleaned = cleaned.dropDuplicates(dedup_keys)

        if DeltaTable.isDeltaTable(spark, target_url):
            # https://docs.delta.io/latest/delta-batch.html#append
            cleaned.write.format('delta').mode('append').save(target_url)
            return

        logger.debug('No existing delta table found. Creating delta table.')
        writer = cleaned.write.format('delta')
        if not partitioning:
            logger.info('Writing unpartitioned delta lake')
        elif len(partitioning_cols) == 0:
            logger.info('Partitioning columns absent - defaulting to DistrictId and SchoolYear as partitioning columns')
            writer = writer.partitionBy('DistrictId', 'SchoolYear')
        else:
            partitioning_str = ', '.join(partitioning_cols)
            logger.info(f'Writing partitioned delta lake - partitioned by - {partitioning_str}')
            writer = writer.partitionBy(*partitioning_cols)
        writer.save(target_url)
    
    def get_sink_general_sensitive_paths(self, source_path):
        """ Returns the pair (general sink path, sensitive lookup sink path) for the given source path.
            Both are built from the parsed 'entity_parent_path' with 'Ingested' replaced by 'Refined';
            the sensitive path carries the '_lookup' suffix on the entity name.
        """
        parsed = self.parse_path(source_path)
        entity_name = parsed['entity']

        general = parsed['entity_parent_path'].replace('Ingested', 'Refined') + '/general/' + entity_name
        sensitive = parsed['entity_parent_path'].replace('Ingested', 'Refined') + '/sensitive/' + entity_name + '_lookup'
        return general, sensitive

    def refine(self, entity_path, metadata=None, primary_key='id'):
        """ Pseudonymizes the latest changes of an ingested entity and upserts them into stage2/Refined.

            entity_path: path of the entity below stage2/Ingested.
            metadata: OEA metadata rows for the entity; looked up from the entity's parent path when not given.
            primary_key: primary key column of the entity (its name is normalized with fix_column_name).

            Returns the number of changed rows that were found.
        """
        source_path = f'stage2/Ingested/{entity_path}'
        primary_key = self.fix_column_name(primary_key) # fix the column name, in case it has a space in it or some other invalid character
        # Must be called on self - the bare function name does not exist at module level.
        sink_general_path, sink_sensitive_path = self.get_sink_general_sensitive_paths(source_path)

        if not metadata:
            # path_dict was previously referenced without ever being created in this method.
            path_dict = self.parse_path(source_path)
            all_metadata = self.get_metadata_from_path(path_dict['entity_parent_path'])
            metadata = all_metadata[path_dict['entity']]

        df_changes = self.get_latest_changes(source_path, sink_general_path)
        spark_schema = self.to_spark_schema(metadata)
        df_changes = self.modify_schema(df_changes, spark_schema)

        # count() is a Spark action; evaluate it once instead of up to three times.
        changed_rows = df_changes.count()
        if changed_rows > 0:
            df_pseudo, df_lookup = self.pseudonymize(df_changes, metadata)
            self.upsert(df_pseudo, sink_general_path, f'{primary_key}_pseudonym') # todo: remove this assumption that the primary key will always be hashed during pseduonymization
            self.upsert(df_lookup, sink_sensitive_path, primary_key)
            self.add_to_lake_db(sink_general_path)
            self.add_to_lake_db(sink_sensitive_path)
            logger.info(f'Processed {changed_rows} updated rows from {source_path} into stage2/Refined')
        else:
            logger.info(f'No updated rows in {source_path} to process.')

        return changed_rows

    def pseudonymize(self, df, metadata, transform_mode = False, primary_key = 'id'): #: list[list[str]]):
        """ Pseudonymizes the given dataframe and returns the tuple (df_pseudo, df_lookup).

            Default mode (transform_mode False): every metadata row is read as [column name, dtype, operation]:
              - "hash-no-lookup" / "hnl": column is replaced by its salted hash (renamed <col>_pseudonym) in df_pseudo
                and dropped from df_lookup, because the lookup can be performed against a different table.
              - "hash" / "h": same hashing in df_pseudo; df_lookup keeps the original column and gains <col>_pseudonym.
              - "mask" / "m": column value becomes '*' in df_pseudo.
              - "partition-by": column is left untouched in both dataframes so it can be used for partitioning.
              - "no-op" / "x": column is dropped from df_lookup.
            transform_mode True: only primary_key is treated, exactly like a "hash" column.

            The lookup dataframe should be written to a "sensitive" folder in the data lake.
            eg, df_pseudo, df_lookup = oea.pseudonymize(df, metadata)
            [More info on this approach here: https://learn.microsoft.com/en-us/azure/databricks/security/privacy/gdpr-delta#pseudonymize-data]
        """
        salt = self._get_salt()

        def salted_hash(name):
            # SHA-256 over the column value concatenated with the salt
            return F.sha2(F.concat(F.col(name), F.lit(salt)), 256)

        def hash_and_rename(frame, name):
            return frame.withColumn(name, salted_hash(name)).withColumnRenamed(name, name + "_pseudonym")

        df_pseudo = df
        df_lookup = df

        if transform_mode:
            col_name = primary_key
            df_pseudo = hash_and_rename(df_pseudo, col_name)
            df_lookup = df_lookup.withColumn(col_name + "_pseudonym", salted_hash(col_name))
            return (df_pseudo, df_lookup)

        for row in metadata:
            col_name, _dtype, op = row[0], row[1], row[2]
            if op in ("hash-no-lookup", "hnl"):
                df_pseudo = hash_and_rename(df_pseudo, col_name)
                df_lookup = df_lookup.drop(col_name)
            elif op in ("hash", 'h'):
                df_pseudo = hash_and_rename(df_pseudo, col_name)
                df_lookup = df_lookup.withColumn(col_name + "_pseudonym", salted_hash(col_name))
            elif op in ("mask", 'm'):
                df_pseudo = df_pseudo.withColumn(col_name, F.lit('*'))
            elif op == "partition-by":
                continue # column stays in both dataframes
            elif op in ("no-op", 'x'):
                df_lookup = df_lookup.drop(col_name)

        return (df_pseudo, df_lookup)

    
    def add_to_lake_db(self, source_entity_path, overwrite = False, extension = None):
        """ Registers the entity found at source_entity_path as a DELTA table in the lake db derived from the path.
            The lake db is created when missing, and the table is only created when it does not exist yet.
            eg: add_to_lake_db('stage2/Ingested/contoso_sis/v0.1/students')

            overwrite: drop an existing table of the same name first.
            extension: optional suffix appended to the entity (table) name.

            Note that a spark db that points to source data in the delta format can't be queried via SQL serverless pool. More info here: https://docs.microsoft.com/en-us/azure/synapse-analytics/sql/resources-self-help-sql-on-demand#delta-lake
        """
        parsed = self.parse_path(source_entity_path)
        db_name = parsed['ldb_name']

        if extension is not None:
            parsed['entity'] = parsed['entity'] + str(extension)

        spark.sql(f'CREATE DATABASE IF NOT EXISTS {db_name}')

        if overwrite:
            drop_statement = f"drop table if exists {db_name}.{parsed['entity']}"
            spark.sql(drop_statement)

        create_statement = f"create table if not exists {db_name}.{parsed['entity']} using DELTA location '{self.to_url(parsed['entity_path'])}'"
        spark.sql(create_statement)

In [None]:
class EdFiClient:
    #The constructor
    def __init__(self, workspace, kvName, moduleName, authUrl, dataManagementUrl, changeQueriesUrl, dependenciesUrl, apiVersion, batchLimit, minChangeVer="", maxChangeVer="", schoolYear=None, districtId=None):
        """Stores the connection settings and resolves the API credentials.

        workspace: workspace name; "dev" allows falling back to the test data set when no credentials are available.
        kvName: key vault name, stored on the notebook-level `oea` object; secrets are read through oea._get_secret.
        moduleName / apiVersion / districtId / schoolYear: used to build the stage1 transactional folder.
        authUrl, dataManagementUrl, changeQueriesUrl, dependenciesUrl: Ed-Fi endpoints used by the other methods.
        batchLimit: page size used when landing data.
        minChangeVer / maxChangeVer: optional change-version window. None or a blank string means "not set"
            and is stored as None, which is the value landEntities checks for.

        Side effects: sets the formatter of every root-logger handler and the root logger level to INFO.
        Raises: whatever oea._get_secret raised, when the credentials cannot be read and workspace is not "dev".
        """
        def _blank_to_none(value):
            # The signature defaults are "", while downstream code only treats None as "not set".
            return None if value is None or str(value).strip() == "" else value

        self.workspace = workspace
        self.keyvault_linked_service = 'LS_KeyVault'
        oea.kvName = kvName

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        for handler in logging.getLogger().handlers:
            handler.setFormatter(formatter)
        # Customize log level for all loggers
        logging.getLogger().setLevel(logging.INFO)
        logger.info(f"minChangeVersion={minChangeVer} and maxChangeVersion={maxChangeVer}")

        if not kvName and workspace == "dev":
            logger.info("defaulting to test data")
            self.clientId = ""
            self.clientSecret = ""
        else:
            try:
                #try to get the credentials from keyvault
                self.clientId = oea._get_secret("edfi-clientid")
                self.clientSecret = oea._get_secret("edfi-clientsecret")
            except Exception as e:
                #if there was an error getting the credentials
                #if this is the dev instance proceed with test data, otherwise raise the Exception
                logger.info(f"failed to retrieve clientId and clientSecret from keyvault with exception: {str(e)}")
                if workspace == "dev":
                    logger.info("defaulting to test data")
                    self.clientId = ""
                    self.clientSecret = ""
                else:
                    raise

        self.authUrl = authUrl
        self.dataManagementUrl = dataManagementUrl
        self.changeQueriesUrl = changeQueriesUrl
        self.dependenciesUrl = dependenciesUrl
        self.runDate = datetime.utcnow().strftime('%Y-%m-%d')
        # Token state; filled in by authenticate()
        self.authTime = None
        self.expiresIn = None
        self.accessToken = None
        districtPath = districtId if districtId is not None else "All"
        schoolYearPath = schoolYear if schoolYear is not None else "All"
        self.transactionalFolder = f"Transactional/{moduleName}/{apiVersion}/DistrictId={districtPath}/SchoolYear={schoolYearPath}"
        self.batchLimit = batchLimit
        self.minChangeVer = _blank_to_none(minChangeVer)
        self.maxChangeVer = _blank_to_none(maxChangeVer)

    #Method to get the access token for the test data set
    def authenticateWithAuthorization(self):
        """Requests a client_credentials token from the hard-coded Ed-Fi v5.2 endpoint using a fixed Basic header; returns the raw response."""
        #TODO: need to update this if we want it to work with other edfi provided test data set versions
        token_url = "https://api.ed-fi.org/v5.2/api/oauth/token"
        form_data = {"grant_type":"client_credentials"}
        request_headers = {"Authorization":"Basic UnZjb2hLejl6SEk0OkUxaUVGdXNhTmY4MXh6Q3h3SGZib2xrQw=="}
        return requests.post(token_url, form_data, headers=request_headers)

    #Method to get the access token for a production system with basic auth
    def authenticateWithBasic(self):
        """Requests a client_credentials token from self.authUrl with HTTP basic auth (clientId / clientSecret); returns the raw response."""
        form_data = {"grant_type":"client_credentials"}
        credentials = HTTPBasicAuth(self.clientId, self.clientSecret)
        return requests.post(self.authUrl, form_data, auth=credentials)

    #This method orchestrates the authentication
    def authenticate(self):
        self.authTime = datetime.now()
        if not self.clientId or not self.clientSecret: #self.workspace == "dev":
            result = self.authenticateWithAuthorization().json()
            logger.info(result)
        else:
            result = self.authenticateWithBasic().json()
        self.expiresIn = result["expires_in"]
        self.accessToken = result["access_token"]
    
    #This method manages the access token, refreshing it when required
    def getAccessToken(self):
        currentTime = datetime.now()
        #Get a new access token if none exists, or if the expires time is within 5 minutes of expiry
        if self.accessToken == None or (currentTime-self.authTime).total_seconds() > self.expiresIn - 300:
            self.authenticate()
            return self.accessToken
        else:
            return self.accessToken 

    def getChangeQueryVersion(self):
        """Returns the parsed JSON body of <changeQueriesUrl>/availableChangeVersions.

        landEntities reads the keys 'OldestChangeVersion' and 'NewestChangeVersion' from the result.
        """
        access_token = self.getAccessToken()
        # Use the URL this client was constructed with; the bare name only resolved when a notebook-level
        # variable of the same name happened to exist.
        response = requests.get(self.changeQueriesUrl + "/availableChangeVersions", headers={"Authorization":"Bearer " + access_token})
        return response.json()
    
    def getEntities(self):
        """Fetches self.dependenciesUrl (no auth header is sent) and returns the parsed JSON body."""
        dependencies_response = requests.get(self.dependenciesUrl)
        return dependencies_response.json()

    def getDeletes(self,resource, minChangeVersion, maxChangeVersion):
        """Requests the /deletes endpoint of the resource for the given change-version window; returns the raw response."""
        query = f"MinChangeVersion={minChangeVersion}&MaxChangeVersion={maxChangeVersion}"
        deletes_url = f"{self.dataManagementUrl}{resource}/deletes?{query}"
        bearer = {"Authorization": f"Bearer {self.getAccessToken()}"}
        return requests.get(deletes_url, headers = bearer)

    def writeToDeletesFile(self, resource, deletes):
        """Writes the text of the deletes response to delete_batch_data/rundate=<runDate>/data.json below the resource's stage1 folder."""
        resource_folder = f"stage1/{self.transactionalFolder}{resource}"
        target = f"{resource_folder}/delete_batch_data/rundate={self.runDate}/data.json"
        mssparkutils.fs.put(oea.to_url(target), deletes.text)

    def landEntities(self, entities = 'All'):
        if entities == 'All':
            entities = self.getEntities()
        else:
            entities = self.getSpecifiedEntities(entities)
        changeVersion = self.getChangeQueryVersion()
        minChangeVersion = changeVersion['OldestChangeVersion'] if self.minChangeVer == None else int(self.minChangeVer)
        maxChangeVersion = changeVersion['NewestChangeVersion']  if self.maxChangeVer == None else int(self.maxChangeVer)
        for entity in entities:
            resource = entity['resource']
            resourceMinChangeVersion = self.getChangeVersion(resource, minChangeVersion) if self.minChangeVer == None else minChangeVersion

            self.landEntity(resource, resourceMinChangeVersion, maxChangeVersion)
            deletes = self.getDeletes(resource,resourceMinChangeVersion,maxChangeVersion)
            if len(deletes.json()):
                self.writeToDeletesFile(resource,deletes)
    
    def getChangeVersion(self, resource, default):
        """Returns the 'changeVersion' stored in the resource's changeFile.json, or `default` when that file does not exist."""
        changeFilePath = f"stage1/{self.transactionalFolder}{resource}/changeFile.json"
        if not mssparkutils.fs.exists(oea.to_url(changeFilePath)):
            return default
        stored = json.loads(mssparkutils.fs.head(oea.to_url(changeFilePath)))
        return stored['changeVersion']

    def landEntity(self,resource,minChangeVersion,maxChangeVersion):
        """Lands all records of `resource` changed between minChangeVersion and maxChangeVersion (inclusive).

        The change-version range is split into contiguous, non-overlapping windows (roughly one per
        batchLimit expected records) and every window is paged with limit/offset. Each non-empty page is
        written as a JSON-lines file below <resource folder>/delta_batch_data/rundate=<runDate>/.
        Keyset pagination background: https://techdocs.ed-fi.org/display/ODSAPIS3V61/Improve+Paging+Performance+on+Large+API+Resources

        When the API does not report a Total-Count header, the records returned by the initial,
        un-paged request are landed as a best effort.

        changeFile.json is only advanced to maxChangeVersion when every request succeeded, so that a
        failed run is retried from the previous change version instead of silently skipping data.
        """
        logger.info(f"initiating {resource}")
        path = f"stage1/{self.transactionalFolder}{resource}"
        resourceUrl = f"{self.dataManagementUrl}{resource}"
        url = f"{resourceUrl}?MinChangeVersion={minChangeVersion}&MaxChangeVersion={maxChangeVersion}&totalCount=true"
        total_count_response = requests.get(url, headers={"Authorization":f"Bearer {self.getAccessToken()}"})
        total_count = self._readTotalCount(total_count_response)
        batch_limit = max(1, int(self.batchLimit))
        succeeded = True

        if total_count_response.status_code >= 400:
            logger.info(f"There was an error retrieving data for {resource}")
            succeeded = False
        elif total_count is None:
            # No usable Total-Count header: land what the un-paged request returned.
            # NOTE(review): this is a single page only - whatever page size the API applies by default.
            if self._writeBatch(path, total_count_response) == 0:
                logger.info(f'No new / updated items b/w the following versions {minChangeVersion} and {maxChangeVersion}')
        elif total_count == 0:
            logger.info(f'No new / updated items b/w the following versions {minChangeVersion} and {maxChangeVersion}')
        else:
            for lower, upper in self._changeVersionWindows(minChangeVersion, maxChangeVersion, total_count, batch_limit):
                #Calculate the number of batches per partition
                partitionUrl = f"{resourceUrl}?MinChangeVersion={lower}&MaxChangeVersion={upper}&totalCount=true"
                partition_count_response = requests.get(partitionUrl, headers={"Authorization":f"Bearer {self.getAccessToken()}"})
                partition_count = self._readTotalCount(partition_count_response)
                if partition_count_response.status_code >= 400 or partition_count is None:
                    logger.info(f"There was an error retrieving batch data for {resource}")
                    succeeded = False
                    continue

                # One request per page; a window without records issues no page request at all.
                for offset in range(0, partition_count, batch_limit):
                    batchUrl = f"{partitionUrl}&limit={batch_limit}&offset={offset}"
                    data = requests.get(batchUrl, headers={"Authorization":f"Bearer {self.getAccessToken()}"})
                    if data.status_code < 400:
                        self._writeBatch(path, data)
                    else:
                        logger.info(f"There was an error retrieving batch data for {resource}")
                        succeeded = False

        if succeeded:
            changeFilepath = f"{path}/changeFile.json"
            changeData = {"changeVersion":maxChangeVersion}
            mssparkutils.fs.put(oea.to_url(changeFilepath),json.dumps(changeData),True)
        else:
            logger.info(f"change version for {resource} was not advanced because some data could not be retrieved")
        logger.info(f"completed {resource}")

    @staticmethod
    def _readTotalCount(response):
        """Returns the integer Total-Count header of a response, or None when it is missing or not a number."""
        try:
            return int(response.headers["Total-Count"])
        except (KeyError, ValueError, TypeError):
            return None

    @staticmethod
    def _changeVersionWindows(minChangeVersion, maxChangeVersion, total_count, batch_limit):
        """Splits [minChangeVersion, maxChangeVersion] into inclusive, non-overlapping (lower, upper) windows, about one per batch_limit records."""
        span = maxChangeVersion - minChangeVersion + 1
        if span <= 0:
            return []
        # -(-a // b) is integer ceiling division
        partitions = max(1, -(-total_count // batch_limit))
        range_size = max(1, -(-span // partitions))
        windows = []
        lower = minChangeVersion
        while lower <= maxChangeVersion:
            upper = min(maxChangeVersion, lower + range_size - 1)
            windows.append((lower, upper))
            lower = upper + 1
        return windows

    def _writeBatch(self, path, response):
        """Writes the records of one response as a JSON-lines file; returns the number of records (0 means nothing was written)."""
        records = json.loads(response.text)
        if len(records) == 0:
            return 0
        filepath = f"{path}/delta_batch_data/rundate={self.runDate}/data{uuid.uuid4()}.json"
        output_string = "".join(json.dumps(record) + "\n" for record in records)
        mssparkutils.fs.put(oea.to_url(filepath),output_string)
        return len(records)
    
    def parse_text_to_dataframe(self, text_content, delimiter=','):
        csv_file = StringIO(text_content)
        df = pd.read_csv(csv_file, delimiter=delimiter) 
        
        return df

    def extract_entities_for_etl(self, df):
        concat_list = []
        entity_names_list = []
        
        for index, row in df.iterrows():
            entity_type = row['entity_type']
            entity_name = row['entity_name']
            
            if entity_type != 'ed-fi':
                concat_list.append(f'/{entity_type}/{entity_name}')
            
            concat_list.append(f'/ed-fi/{entity_name}')
            entity_names_list.append(entity_name)
        
        return concat_list, list(set(entity_names_list))


    def getSpecifiedEntities(self, entities_list):
        data = self.getEntities()
        entities = [item for item in data if item['resource'] in entities_list]
        return entities

    def listSpecifiedEntities(self, path): 
        """Reads <path>/entities-to-extract.csv and returns (api entity paths, entity names).
        Two empty lists are returned when the file does not exist."""
        csv_path = path + '/entities-to-extract.csv'
        if not oea.path_exists(csv_path):
            return list(), list()

        csv_text = oea.get_text_from_path(csv_path)
        frame = self.parse_text_to_dataframe(csv_text, delimiter=',')
        api_entities, entities = self.extract_entities_for_etl(frame)
        return api_entities, entities

In [None]:
# Replace `oea` with the Ed-Fi specific subclass defined above.
oea = EdFiOEAChild()
# The new instance is built with the constructor defaults, so the workspace configured on the previous
# `oea` object earlier in this notebook would be lost; apply it again here.
oea.set_workspace(workspace)
# Client used to land Ed-Fi data. The arguments other than `workspace` are expected to be defined
# before this cell runs (see the parameters cell).
edfi = EdFiClient(workspace, 
                  kvName, 
                  moduleName, 
                  authUrl, 
                  dataManagementUrl, 
                  changeQueriesUrl, 
                  dependenciesUrl, 
                  apiVersion, 
                  batchLimit, 
                  minChangeVer, 
                  maxChangeVer)