In [None]:
####### OEA configuration #############
# Default settings used by the OEA class when the corresponding argument is not passed to OEA(...).
# Replace these placeholder values with the names of your own resources.
oea_storage_account = 'yourstorageaccount'  # storage account name used to build the abfss:// stage urls
oea_keyvault = 'yourkeyvault'  # keyvault name used by OEA._get_secret (eg, to fetch the 'oeaSalt' secret)
oea_timezone = 'US/Eastern'  # timezone name stored on the OEA instance as self.timezone
#######################################

from delta.tables import DeltaTable
from notebookutils import mssparkutils
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, ArrayType, TimestampType, BooleanType, ShortType, DateType
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
import logging
import pandas as pd
import sys
import re
import json
import datetime
import pytz
import random
import io
import urllib.request

logger = logging.getLogger(name='OEA')  # shared logger for every message emitted by the OEA class

class OEA:
    """ OEA (Open Education Analytics) framework simplifies the process of working with large data sets within the context of a lakehouse architecture.
        Definition of terms used throughout this codebase:
        path - a complete or partial folder or file path (does not include details like scheme or domain name as found in a URL). Ex: contosos/v0.1/students
        entity_path - a path that ends with a folder that contains entity data. Ex: contoso/v0.1/students
        dataset_path - a path that ends with a folder that contains entity folders (entity parent folder). Ex: contoso/v0.1
        url - includes scheme and domain name. Ex: abfss://stage1@storageaccount.dfs.core.windows.net/contoso/v0.1/students

    """
    DELTA_BATCH_DATA = 'delta_batch_data'
    ADDITIVE_BATCH_DATA = 'additive_batch_data'
    SNAPSHOT_BATCH_DATA = 'snapshot_batch_data'

    def __init__(self, workspace='dev', logging_level=logging.INFO, storage_account=None, keyvault=None, timezone=None):
        """ Creates an OEA helper bound to a storage account, a keyvault and a workspace.
            storage_account, keyvault and timezone fall back to the module-level oea_* values when they are not given (or falsy).
            Side effects: configures the root logger, resolves the stage urls for the workspace (see set_workspace),
            and turns on Delta optimized writes for the current spark session.
        """
        # Fixed names of the keyvault linked service and of the secret that holds the pseudonymization salt.
        self.keyvault_linked_service = 'LS_KeyVault'
        self.salt_secret_name = 'oeaSalt'
        self.salt = None  # fetched lazily by _get_salt()

        default_account, default_vault, default_zone = oea_storage_account, oea_keyvault, oea_timezone
        self.workspace = workspace
        self.storage_account = storage_account or default_account
        self.keyvault = keyvault or default_vault
        self.timezone = timezone or default_zone
        if logging_level:
            self.logging_level = logging_level

        self._initialize_logger(logging_level)
        self.set_workspace(self.workspace)
        # more info here: https://learn.microsoft.com/en-us/azure/synapse-analytics/spark/optimize-write-for-apache-spark
        spark.conf.set("spark.microsoft.delta.optimizeWrite.enabled", "true")
        logger.info("OEA initialized.")

    def _initialize_logger(self, logging_level):
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        for handler in logging.getLogger().handlers:
            handler.setFormatter(formatter)           
        # Customize log level for all loggers
        logging.getLogger().setLevel(logging_level)        

    def _get_secret(self, secret_name):
        """ Returns the value of secret_name from the keyvault named in self.keyvault.
            The lookup goes through the Synapse TokenLibrary, using the linked service named in self.keyvault_linked_service.
            Assumes that linked service has been set up and is accessible.
        """
        jvm = SparkSession.builder.getOrCreate()._jvm
        token_library = jvm.com.microsoft.azure.synapse.tokenlibrary.TokenLibrary
        return token_library.getSecret(self.keyvault, secret_name, self.keyvault_linked_service)

    def _get_salt(self):
        if not self.salt:
            self.salt = self._get_secret(self.salt_secret_name)
        return self.salt

    def set_workspace(self, workspace_name):
        """ Points OEA at the stage folders of the given workspace and stores the urls in self.stage1/2/3.
            'prod'/'production' -> one container per stage: abfss://stageN@<account>.dfs.core.windows.net
            'dev'/'development' -> abfss://oea@<account>.dfs.core.windows.net/dev/stageN
            anything else       -> a sandbox: abfss://oea@<account>.dfs.core.windows.net/sandboxes/<workspace_name>/stageN
            (eg, specifying Jon means OEA reads from sandboxes/Jon/stage1 instead of stage1)
        """
        stage_names = ('stage1', 'stage2', 'stage3')
        if workspace_name in ('prod', 'production'):
            self.workspace = 'prod'
            stage_roots = ['abfss://' + name + '@' + self.storage_account + '.dfs.core.windows.net' for name in stage_names]
        else:
            if workspace_name in ('dev', 'development'):
                self.workspace = 'dev'
                base_url = f'abfss://oea@{self.storage_account}.dfs.core.windows.net/dev'
            else:
                self.workspace = workspace_name
                base_url = f'abfss://oea@{self.storage_account}.dfs.core.windows.net/sandboxes/{workspace_name}'
            stage_roots = [f'{base_url}/{name}' for name in stage_names]
        self.stage1, self.stage2, self.stage3 = stage_roots
        logger.info(f'Now using workspace: {self.workspace}')

    def to_url(self, path):
        """ Translates a lake path into a full abfss url for the current workspace.
            eg, to_url('stage1/contoso_sis/student') returns f'{self.stage1}/contoso_sis/student'. In the 'prod' workspace that is
            abfss://stage1@storageaccount.dfs.core.windows.net/contoso_sis/student; in a 'sam' sandbox it is
            abfss://oea@storageaccount.dfs.core.windows.net/sandboxes/sam/stage1/contoso_sis/student.
            A value that already starts with 'abfss://' is returned unchanged, so it is safe to call on either form.
            Raises ValueError if path is empty or does not start with 'stage1', 'stage2', or 'stage3'.
        """
        if not path:
            raise ValueError('Specified path cannot be empty.')
        if path.startswith('abfss://'):
            return path
        stage_name, _, remainder = path.partition('/')
        if stage_name not in ('stage1', 'stage2', 'stage3'):
            raise ValueError("Path must begin with either 'stage1', 'stage2', or 'stage3'")
        url = f"{getattr(self, stage_name)}/{remainder}"
        logger.debug(f'to_url: {url}')
        return url

    def parse_path(self, path):
        """ Parses a lake path that points either at an entity folder or at the parent folder of entity folders.
            The path must start with <stage>/<category>/<source system>, eg:
                stage1/Transactional/contoso_sis/v0.1
                stage1/Transactional/contoso_sis/v0.1/studentattendance
                stage2/Ingested/contoso_sis/v0.1/studentattendance
            and returns a dictionary like one of the following:
                {'stage': 'stage1', 'stage_num': '1', 'category': 'Transactional', 'source_system': 'contoso_sis', 'entity': None, 'entity_list': ['studentattendance'], 'entity_path': None, 'entity_parent_path': 'stage1/Transactional/contoso_sis/v0.1', ...}
                {'stage': 'stage1', 'stage_num': '1', 'category': 'Transactional', 'source_system': 'contoso_sis', 'entity': 'studentattendance', 'entity_list': None, 'entity_path': 'stage1/Transactional/contoso_sis/v0.1/studentattendance', 'entity_parent_path': 'stage1/Transactional/contoso_sis/v0.1', ...}
            The dictionary also contains:
                'sdb_name' / 'ldb_name' - names of the sql db and lake db for this source, eg sdb_dev_s2i_contoso_sis_v0p1
                'between_path' - the entity parent path without its stage and category, eg contoso_sis/v0.1
                'version' - the last path segment after the source system that starts with 'v' (eg 'v0.1'), or None
            A trailing '/' on the path is ignored. If a dict is passed in, it is assumed to be parsed already and is returned unchanged.

            An entity folder is recognised by a batch data folder (stage1) or a _delta_log folder (stage2 and stage3).
            This method assumes the standard OEA data lake, in which paths have this structure: <stage number>/<category>/<source system>/<optional version and partitioning>/<entity>/<either batch_data folder or _delta_log>
            Raises ValueError if the path does not contain at least the stage, category, and source system.
        """
        if isinstance(path, dict): return path # this means the path was already parsed

        path = path.rstrip('/') # a trailing slash would otherwise produce an empty entity name
        ar = path.split('/')
        if len(ar) < 3 or not ar[0]:
            raise ValueError(f'Path must start with <stage>/<category>/<source system> (eg, stage1/Transactional/contoso_sis), but got: {path}')
        path_dict = {'stage':ar[0], 'stage_num':ar[0][-1], 'category':ar[1], 'source_system':ar[2], 'entity':None, 'entity_list':None, 'entity_path':None, 'entity_parent_path':None}

        folders = self.get_folders(self.to_url(path))

        # An entity folder holds a batch data folder in stage1, or a "_delta_log" folder in stage2 and stage3
        stage_num = path_dict['stage_num']
        batch_folders = (self.ADDITIVE_BATCH_DATA, self.DELTA_BATCH_DATA, self.SNAPSHOT_BATCH_DATA)
        is_entity = (stage_num == '1' and any(name in batch_folders for name in folders)) or (stage_num in ('2', '3') and '_delta_log' in folders)
        if is_entity:
            path_dict['entity'] = ar[-1]
            path_dict['entity_path'] = path
            path_dict['entity_parent_path'] = '/'.join(ar[0:-1]) # eg, stage1/Transactional/contoso_sis/v0.1
        else:
            path_dict['entity_list'] = folders
            path_dict['entity_parent_path'] = path

        # lower case the source system to match the naming for lake dbs (spark lower cases db names; we do it here to be explicit)
        source_system = path_dict['source_system'].lower()
        if path_dict['stage'] == 'stage2':
            abbrev = path_dict['category'][0].lower() # either 'i' for Ingested or 'r' for Refined
            db_suffix = f's{stage_num}{abbrev}_{source_system}'
        else:
            db_suffix = f's{stage_num}_{source_system}'
        path_dict['sdb_name'] = f'sdb_{self.workspace}_{db_suffix}' # name of the sql db for this source
        path_dict['ldb_name'] = f'ldb_{self.workspace}_{db_suffix}' # name of the lake db for this source

        # strip off the stage and category in the entity parent path (eg, stage1/Transactional), which leaves eg contoso_sis/v0.1
        path_dict['between_path'] = '/'.join(path_dict['entity_parent_path'].split('/')[2:])

        m = re.match(r'.*\/(v[^\/]+).*', path_dict['between_path'])
        if m:
            path_dict['version'] = m.group(1)
            # Append the version to the db names, replacing '.' with 'p' because a '.' is not allowed in a db name
            safe_version = path_dict['version'].replace('.', 'p')
            path_dict['sdb_name'] = f'{path_dict["sdb_name"]}_{safe_version}'
            path_dict['ldb_name'] = f'{path_dict["ldb_name"]}_{safe_version}'
        else:
            path_dict['version'] = None

        return path_dict
    

    def rm_if_exists(self, path, recursive_remove=True):
        """ Remove a folder if it exists (defaults to use of recursive removal). 
            eg. rm_if_exists('stage1/Transactional/contoso_sis/v0.1/students') 
            will delete the folder 'stage1/Transactional/contoso_sis/v0.1/students'.
            Note that the path should always start with stage1/stage2/stage3/oea.
        """
        try:
            mssparkutils.fs.rm(self.to_url(path), recursive_remove)
        except Exception as e:
            pass

    def delete(self, path):
        """ Recursively deletes the folder at path together with everything inside it (a missing path is not an error). """
        return self.rm_if_exists(path, recursive_remove=True)

    def ls(self, path):
        """ Lists the entries directly under path and returns them as a (folder_names, file_names) tuple.
            If listing fails (eg, the path does not exist), a warning is logged. The lists then hold only the
            entries gathered before the failure, which is normally none.
        """
        url = self.to_url(path)
        folder_names, file_names = [], []
        try:
            for entry in mssparkutils.fs.ls(url):
                if entry.isFile:
                    file_names.append(entry.name)
                elif entry.isDir:
                    folder_names.append(entry.name)
        except Exception as e:
            logger.warning("[OEA] Could not peform ls on specified path: " + path + "\nThis may be because the path does not exist.")
        return (folder_names, file_names)

    def path_exists(self, path):
        """ Reports whether path can be listed in the lake. Never raises: any failure counts as "does not exist".
            eg, path_exists('stage1/mytest/v1.0')
        """
        try:
            mssparkutils.fs.ls(self.to_url(path))
        except Exception:
            # A missing path surfaces as a generic Py4JJavaError, so every failure is treated as "not found".
            return False
        else:
            return True

    def get_stage_num(self, path):
        """ Returns the stage number named in the given path or url, as a one-character string (eg '1').
            Works for lake paths (stage2/Ingested/...), dev/sandbox urls (.../dev/stage2/...), prod urls where the
            stage is the container name (abfss://stage2@account.dfs.core.windows.net/...), and a bare 'stage2'.
            If several stage segments appear, the last one wins.
            Raises ValueError if no stage can be found.
        """
        # 'stageN' must be followed by '/', '@' (prod container name) or the end of the string
        m = re.match(r'.*stage(\d)(?:[/@]|$)', path)
        if m:
            return m.group(1)
        raise ValueError("Path must begin with either 'stage1', 'stage2', or 'stage3'")

    def get_folders(self, path):
        """ Returns the names of the sub-folders directly under path (a lake path or an abfss url).
            On failure (eg, the path does not exist) a warning is logged, and the names collected so far
            (normally none) are returned.
        """
        folder_names = []
        try:
            for entry in mssparkutils.fs.ls(self.to_url(path)):
                if not entry.isDir:
                    continue
                folder_names.append(entry.name)
        except Exception as e:
            logger.warning("[OEA] Could not get list of folders in specified path: " + path + "\nThis may be because the path does not exist.")
        return folder_names

    def get_latest_folder(self, path):
        """ Returns the name of the last sub-folder listed under path, or None when there are no sub-folders.
            "Latest" relies on the order returned by get_folders().
            NOTE(review): this assumes the listing is name-sorted, so rundate=... folders come out chronologically. Confirm for mssparkutils.fs.ls.
        """
        folders = self.get_folders(path)
        return folders[-1] if folders else None

    def contains_batch_folder(self, path):
        """ Returns True if path has an OEA batch folder (additive_batch_data, snapshot_batch_data or delta_batch_data) directly under it, else False. """
        batch_folder_names = ('additive_batch_data', 'snapshot_batch_data', 'delta_batch_data')
        return any(name in batch_folder_names for name in self.get_folders(self.to_url(path)))

    def get_batch_info(self, source_path):
        """ Given a stage1 entity path, returns a (batch_type, file_extension) tuple.
            batch_type comes from the name of the last batch folder under the entity: the text before the first '_',
            ie 'additive', 'snapshot' or 'delta'.
            file_extension is the last extension of the first file in that batch folder's latest rundate folder.
            eg, get_batch_info('stage1/Transactional/sis/v1.0/students') # returns ('snapshot', 'csv')
            Raises ValueError if no batch folder, rundate folder, or data file (with an extension) can be found.
        """
        url = self.to_url(source_path)
        source_folder_name = self.get_latest_folder(url) # expects to find one of: additive_batch_data, snapshot_batch_data, delta_batch_data
        if not source_folder_name:
            raise ValueError(f'No batch data folder found in: {source_path}')
        batch_type = source_folder_name.split('_')[0]

        rundate_dir = self.get_latest_folder(f'{url}/{source_folder_name}')
        if not rundate_dir:
            raise ValueError(f'No rundate folder found in: {source_path}/{source_folder_name}')
        data_files = self.ls(f'{url}/{source_folder_name}/{rundate_dir}')[1]
        if not data_files:
            raise ValueError(f'No data files found in: {source_path}/{source_folder_name}/{rundate_dir}')

        file_name = data_files[0]
        if '.' not in file_name:
            raise ValueError(f'Cannot determine the data format of {file_name} because it has no file extension')
        file_extension = file_name.rsplit('.', 1)[-1] # use the last extension so names like 'students.2023.csv' resolve to 'csv'
        return batch_type, file_extension

    def load(self, path):
        """ Reads the Delta table stored at path and returns it as a spark DataFrame. """
        return spark.read.load(self.to_url(path), format='delta')

    def display(self, path, limit=4):
        """ Renders the first `limit` rows (default 4) of the Delta table at path using the notebook's display() function.
            Returns the full, un-limited DataFrame.
        """
        table_url = self.to_url(path)
        table_df = spark.read.format('delta').load(table_url)
        preview_df = table_df.limit(limit)
        display(preview_df) # the notebook-level display(), not this method
        return table_df

    def show(self, path, limit=4):
        """ Prints the first `limit` rows (default 4) of the Delta table at path using DataFrame.show().
            Returns the full DataFrame.
        """
        table_url = self.to_url(path)
        table_df = spark.read.load(table_url, format='delta')
        table_df.show(limit)
        return table_df

    def fix_column_names(self, df):
        """ Returns a DataFrame with the same columns as df, renamed with fix_column_name so they satisfy Parquet naming rules. """
        renamed_columns = [F.col(original_name).alias(self.fix_column_name(original_name)) for original_name in df.columns]
        return df.select(renamed_columns)

    def fix_column_name(self, column_name):
        """ Returns column_name with each run of characters that Delta/Parquet reject in column names
            (space , ; { } ( ) newline tab =) replaced by a single underscore.
        """
        illegal_chars = re.compile("[ ,;{}()\n\t=]+")
        return illegal_chars.sub("_", column_name)

    def to_spark_schema(self, schema):#: list[list[str]]):
        """ Creates a spark schema from a schema specified in the OEA schema format.
            Each row is [column_name, data_type, ...]. Only the first two items are used; a third item, such as the
            pseudonymization op, doesn't affect the schema. Every field is created as nullable.
            data_type is matched case-insensitively, ignoring surrounding whitespace. Supported values:
            string, integer/int, short, byte, long/bigint, float, double, decimal (DecimalType() defaults),
            boolean/bool, date, timestamp, binary, struct (an empty struct).
            Example:
            schemas['Person'] = [['Id','string','hash'],
                                    ['CreateDate','timestamp','no-op'],
                                    ['LastModifiedDate','timestamp','no-op']]
            to_spark_schema(schemas['Person'])
            Raises ValueError for an unsupported data type.
        """
        # An explicit lookup table (instead of searching globals()) so that only real Spark types can be produced
        from pyspark.sql import types as spark_types
        type_map = {
            'string': spark_types.StringType,
            'integer': spark_types.IntegerType, 'int': spark_types.IntegerType,
            'short': spark_types.ShortType, 'byte': spark_types.ByteType,
            'long': spark_types.LongType, 'bigint': spark_types.LongType,
            'float': spark_types.FloatType, 'double': spark_types.DoubleType, 'decimal': spark_types.DecimalType,
            'boolean': spark_types.BooleanType, 'bool': spark_types.BooleanType,
            'date': spark_types.DateType, 'timestamp': spark_types.TimestampType,
            'binary': spark_types.BinaryType,
            'struct': spark_types.StructType, # yields an empty struct; accepted for backward compatibility
        }
        fields = []
        for row in schema:
            col_name, dtype = row[0], row[1]
            type_class = type_map.get(dtype.strip().lower())
            if type_class is None:
                raise ValueError(f"Unsupported data type '{dtype}' for column '{col_name}'")
            fields.append(spark_types.StructField(col_name, type_class(), True))
        return spark_types.StructType(fields)

    def get_text_from_path(self, path, max_bytes=9000000):
        """ Returns the text contents of the file at the given lake path.
            Only the first max_bytes bytes are read (default 9000000).
            eg: get_text_from_path('stage1/Transactional/contoso_sis/v0.1/students/part1.csv')
        """
        # resolve the url with this instance (not a global 'oea' object) so the current workspace is honoured
        return mssparkutils.fs.head(self.to_url(path), max_bytes)

    def get_text_from_url(self, url):
        """ Downloads the document at the given url and returns its body decoded as UTF-8 text.
            eg: get_text_from_url("https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/modules/module_catalog/Student_and_School_Data_Systems/metadata.csv")
        """
        with urllib.request.urlopen(url) as response: # closes the connection once the body has been read
            return response.read().decode('utf-8')

    def get_metadata_from_url(self, url):
        """ Returns the Metadata objects by retrieving the contents of the metadata file from given URL and parsing it.
            eg: get_metadata_from_url('https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/gene/v0.7dev/modules/module_catalog/Student_and_School_Data_Systems/metadata.csv')
        """
        csv_str = self.get_text_from_url(url)
        metadata = self.parse_metadata_from_csv(csv_str)
        return metadata   

    def get_metadata_from_path(self, path):
        """ Reads the metadata.csv file in the given folder and returns it parsed into the metadata dictionary.
            Pass the folder that contains metadata.csv, not the file itself; '/metadata.csv' is appended to it
            (a trailing '/' on the folder is ignored).
            eg: get_metadata_from_path('stage2/Ingested/contoso_sis/v0.1') # reads stage2/Ingested/contoso_sis/v0.1/metadata.csv
        """
        csv_str = self.get_text_from_path(path.rstrip('/') + '/metadata.csv')
        return self.parse_metadata_from_csv(csv_str)

    def land_metadata_from_url(self, metadata_url, dataset_path):
        """ Downloads the metadata csv at metadata_url and stores it as stage2/Ingested/<dataset_path>/metadata.csv.
            eg: land_metadata_from_url('https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/gene/v0.7dev/modules/module_catalog/Student_and_School_Data_Systems/metadata.csv', 'contoso_sis/v0.1')
            Note: dataset_path is a folder; do not include the file name in it.
        """
        destination_path = self._metadata_path(dataset_path)
        downloaded_csv = self.get_text_from_url(metadata_url)
        self.write(downloaded_csv, destination_path)

    def _metadata_path(self, dataset_path):
        return f'stage2/Ingested/{dataset_path}/metadata.csv'

    def parse_metadata_from_csv(self, csv_str):
        """ Parses out metadata from a csv string and returns the metadata dictionary. 
        """
        metadata = {}
        current_entity = ''
        header = None
        for line in csv_str.splitlines():
            line = line.strip()
            # skip empty lines, lines that start with # (because these are comments), and lines with only commas (which is what happens if someone uses excel and leaves a row blank) 
            if len(line) == 0 or line.startswith('#') or re.match(r'^,+$', line): continue
            ar = line.split(',')

            if not header:
                header = []
                for column_name in ar:
                    header.append(re.sub("[ ,;{}()\n\t=]+", "_", column_name))
                continue
            
            # check for the start of a new entity definition
            if ar[0] != '':
                current_entity = ar[0]
                metadata[current_entity] = []
            # an attribute row must have an attribute name in the second column
            elif len(ar[1]) > 0:
                ar = ar[1:] # remove the first element because it will be blank
                ar[0] = self.fix_column_name(ar[0]) # remove spaces and other illegal chars from column names
                metadata[current_entity].append(ar)
            else:
                logger.info('Invalid metadata row: ' + line)
        return metadata

    def write(self, data_str, destination_path_and_filename):
        """ Writes data_str to the file at the given lake path. The final True passed to mssparkutils.fs.put is its
            overwrite flag, so an existing file is replaced.
        """
        target_url = self.to_url(destination_path_and_filename)
        mssparkutils.fs.put(target_url, data_str, True)

    def create_run_date(self):
        """ Returns the current UTC time, truncated to whole seconds, as a naive datetime. It is used as the
            rundate=<...> folder name when landing data.
            UTC is used because when parsing it out later, spark's to_timestamp() assumes the local machine's timezone,
            and the timezone for the spark cluster will be UTC. The tzinfo is dropped so that str(rundate) keeps the
            'YYYY-MM-DD HH:MM:SS' form, without a '+00:00' suffix.
        """
        rundate = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0, tzinfo=None)
        return rundate

    def land(self, data, entity_path, filename, batch_data_type=DELTA_BATCH_DATA, rundate=None):
        """ Writes data into stage1 as a new batch for the given entity and returns the lake path that was written.
            The file lands at stage1/Transactional/<entity_path>/<batch_data_type>/rundate=<rundate>/<filename>.
            When no rundate is given, create_run_date() supplies one.
            eg, land(data, 'contoso/v0.1/students', 'students.csv', oea.SNAPSHOT_BATCH_DATA)
        """
        run_stamp = rundate if rundate else self.create_run_date()
        sink_path = f'stage1/Transactional/{entity_path}/{batch_data_type}/rundate={run_stamp}/{filename}'
        self.write(data, sink_path)
        return sink_path

    def upsert(self, df, destination_path, primary_key='id'):
        """ Merges df into the Delta table at destination_path, matching rows on the primary_key column.
            Matching rows are updated with all columns from df, and unmatched rows are inserted.
            Column names are sanitized with fix_column_names first. If no Delta table exists at the destination yet,
            df is written there to create one.
        """
        destination_url = self.to_url(destination_path)
        updates_df = self.fix_column_names(df)
        if not DeltaTable.isDeltaTable(spark, destination_url):
            logger.debug('No existing delta table found. Creating delta table.')
            updates_df.write.format('delta').save(destination_url)
            return
        merge_condition = f'sink.{primary_key} = updates.{primary_key}'
        (DeltaTable.forPath(spark, destination_url).alias('sink')
            .merge(updates_df.alias('updates'), merge_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute())

    def overwrite(self, df, destination_path):
        """ Replaces the contents of the Delta table at destination_path with df, creating the table if needed.
            Column names are sanitized with fix_column_names first.
        """
        target_url = self.to_url(destination_path)
        cleaned_df = self.fix_column_names(df)
        cleaned_df.write.format('delta').mode('overwrite').save(target_url) # https://docs.delta.io/latest/delta-batch.html#overwrite

    def append(self, df, destination_path):
        """ Adds the rows of df to the Delta table at destination_path, creating the table if none exists yet.
            Column names are sanitized with fix_column_names first.
        """
        destination_url = self.to_url(destination_path)
        delta_writer = self.fix_column_names(df).write.format('delta')
        if not DeltaTable.isDeltaTable(spark, destination_url):
            logger.debug('No existing delta table found. Creating delta table.')
            delta_writer.save(destination_url)
        else:
            delta_writer.mode('append').save(destination_url) # https://docs.delta.io/latest/delta-batch.html#append

    def process(self, source_path, foreach_batch_function, options=None):
        """ This simplifies the process of using structured streaming when processing transformations.
            Provide a source_path and a function that receives a dataframe to work with (which will be a dataframe with data from the given source_path).
            The stream runs once (trigger once) and keeps its checkpoint in <source_path>/_checkpoints, so each call only sees data
            that previous calls haven't processed.
            options are passed to the stream reader's load(). A 'format' entry replaces the default 'delta' source format;
            for example, ingest passes {'format': 'csv', 'header': True}.
            Use it like this...
            def refine_contoso_dataset(df_source):
                metadata = oea.get_metadata_from_url('https://raw.githubusercontent.com/microsoft/OpenEduAnalytics/gene/v0.7dev/modules/module_catalog/Student_and_School_Data_Systems/metadata.csv')
                df_pseudo, df_lookup = oea.pseudonymize(df_source, metadata['studentattendance'])
                oea.upsert(df_pseudo, 'stage2/Refined/contoso_sis/v0.1/studentattendance/general')
                oea.upsert(df_lookup, 'stage2/Refined/contoso_sis/v0.1/studentattendance/sensitive')
            oea.process('stage2/Ingested/contoso_sis/v0.1/studentattendance', refine_contoso_dataset)
            Returns the number of input rows processed in this run (0 when there was nothing new).
            Raises ValueError if source_path does not exist.
        """
        if not self.path_exists(source_path):
            raise ValueError(f'The given path does not exist: {source_path} (which resolves to: {self.to_url(source_path)})')
        if options is None: options = {}

        def wrapped_function(df, batch_id):
            df.persist() # cache the df so it doesn't get read in multiple times when we write to multiple destinations. See: https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#foreachbatch
            try:
                foreach_batch_function(df)
            finally:
                df.unpersist() # release the cache even if the batch function fails

        spark.sql("set spark.sql.streaming.schemaInference=true")
        streaming_df = spark.readStream.format('delta').load(self.to_url(source_path), **options)
        # for more info on append vs complete vs update modes for structured streaming: https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#basic-concepts
        query = streaming_df.writeStream.format('delta').outputMode('append').trigger(once=True).option('checkpointLocation', self.to_url(source_path) + '/_checkpoints').foreachBatch(wrapped_function).start()
        query.awaitTermination()   # block until query is terminated, with stop() or with error; A StreamingQueryException will be thrown if an exception occurs.
        progress = query.lastProgress # None if the query never reported any progress
        number_of_new_inbound_rows = progress["numInputRows"] if progress else 0
        logger.info(f'Number of new inbound rows processed: {number_of_new_inbound_rows}')
        logger.debug(progress)
        return number_of_new_inbound_rows

    def is_entity_path(self, path=None):
        """ Returns True if the given lake path (eg, stage1/Transactional/contoso_sis/v0.1/students) points at an entity folder, else False.
            The rule matches parse_path:
            - in stage1, an entity folder holds a batch data folder (additive_batch_data, delta_batch_data or snapshot_batch_data);
            - in stage2 and stage3, it holds a _delta_log folder.
            Returns False when no path is given, or when the path does not start with stage1/2/3.
        """
        if not path:
            return False
        stage_num = path.split('/')[0][-1:]
        if stage_num not in ('1', '2', '3'):
            return False
        folders = self.get_folders(self.to_url(path))
        if stage_num == '1':
            batch_folders = (self.ADDITIVE_BATCH_DATA, self.DELTA_BATCH_DATA, self.SNAPSHOT_BATCH_DATA)
            return any(name in batch_folders for name in folders)
        return '_delta_log' in folders
    
    def ingest_all(self, dataset_path, primary_key='id', options=None):
        """ Ingests every entity folder found under stage1/Transactional/<dataset_path> (see ingest).
            Returns the total number of new inbound rows processed.
            CSV files are expected to have a header row by default, and JSON files are expected to have complete JSON docs on each row in the file.
            To specify options that are different from these defaults, use the options param; the same options apply to every entity.
            eg, ingest_all('contoso_sis/v0.1') # ingests all entities found in that path
            eg, ingest_all('contoso_sis/v0.1', options={'header':False}) # for CSV files that don't have a header
        """
        folders = self.get_folders(self.to_url(f'stage1/Transactional/{dataset_path}'))
        number_of_new_inbound_rows = 0
        for entity_name in folders:
            # Give each entity its own copy, so settings that ingest() adds (eg, the detected 'format')
            # can't leak into the next entity or back to the caller.
            entity_options = dict(options) if options else {}
            number_of_new_inbound_rows += self.ingest(f'{dataset_path}/{entity_name}', primary_key, entity_options)
        return number_of_new_inbound_rows

    def ingest(self, entity_path, primary_key='id', options=None):
        """ Ingests the data for one entity from stage1/Transactional/<entity_path> into stage2/Ingested/<entity_path>.
            Returns the number of new inbound rows processed.
            How the data is written depends on the batch folder found in stage1:
                snapshot_batch_data - the latest rundate folder overwrites the ingested table
                additive_batch_data - new rows are appended
                delta_batch_data    - new rows are upserted using primary_key
            CSV files are expected to have a header row by default, and JSON files are expected to have complete JSON docs on each row in the file.
            To specify options that are different from these defaults, use the options param (the given dict is not modified).
            When any rows were processed, the ingested table is passed to add_to_lake_db.
            eg, ingest('contoso_sis/v0.1/students') # ingests the students entity
            eg, ingest('contoso_sis/v0.1/students', options={'header':False}) # for CSV files that don't have a header
            Raises ValueError if no valid batch folder is found.
        """
        primary_key = self.fix_column_name(primary_key) # fix the column name, in case it has a space in it or some other invalid character
        ingested_path = f'stage2/Ingested/{entity_path}'
        raw_path = f'stage1/Transactional/{entity_path}'
        batch_type, source_data_format = self.get_batch_info(raw_path)
        logger.info(f'Ingesting from: {raw_path}, batch type of: {batch_type}, source data format of: {source_data_format}')

        if batch_type == 'snapshot':
            def batch_func(df): self.overwrite(df, ingested_path)
        elif batch_type == 'additive':
            def batch_func(df): self.append(df, ingested_path)
        elif batch_type == 'delta':
            def batch_func(df): self.upsert(df, ingested_path, primary_key)
        else:
            raise ValueError("No valid batch folder was found at that path (expected to find a single folder with one of the following names: snapshot_batch_data, additive_batch_data, or delta_batch_data). Are you sure you have the right path?")

        source_url = self.to_url(f'{raw_path}/{batch_type}_batch_data')
        if batch_type == 'snapshot': source_url = f'{source_url}/{self.get_latest_folder(source_url)}' # a snapshot only ingests its latest rundate folder
        logger.debug(f'Processing {batch_type} data from: {source_url} and writing out to: {ingested_path}')

        # Work on a copy so that neither the caller's dict nor a shared default is modified by the settings added below
        options = dict(options) if options else {}
        options['format'] = source_data_format # eg, 'csv', 'json'
        if source_data_format == 'csv' and options.get('header') is None: options['header'] = True  # default to expecting a header in csv files

        number_of_new_inbound_rows = self.process(source_url, batch_func, options)
        if number_of_new_inbound_rows > 0:
            self.add_to_lake_db(ingested_path)
        return number_of_new_inbound_rows

    def query(self, source_path, query_str, criteria_str=None):
        """ Runs a SQL query against the Delta table at source_path and returns the resulting DataFrame.
            query_str is the select clause (eg 'select max(rundate) maxdatetime'). ' from tmp_source_table' is appended
            to it, followed by ' where <criteria_str>' when criteria_str is given.
            The table is exposed to SQL as the temporary view 'tmp_source_table', replacing any earlier view with that name.
            Note: query_str and criteria_str are inserted into the SQL text as-is, so don't pass untrusted input.
        """
        df = self.load(source_path)
        # use the session's temp views (the same `spark` session used everywhere else) instead of the deprecated SQLContext API
        df.createOrReplaceTempView('tmp_source_table')
        sql_text = f'{query_str} from tmp_source_table'
        if criteria_str:
            sql_text = f'{sql_text} where {criteria_str}'
        return spark.sql(sql_text)

    def get_latest_changes(self, source_path, sink_path):
        """ Returns the rows of the source table that are newer than anything already in the sink.
            The watermark is max(rundate) of the Delta table at sink_path; source rows with a later rundate are returned.
            The whole source table is returned when:
            - sink_path has no Delta table yet (the query raises AnalysisException), on the assumption that the sink is being created for the first time; or
            - the sink's max(rundate) is empty.
            eg, get_latest_changes('stage2/Ingested/contoso/v0.1/students', 'stage2/Refined/contoso/v0.1/students')
        """
        watermark = None
        try:
            watermark = self.query(sink_path, 'select max(rundate) maxdatetime').first()['maxdatetime']
        except AnalysisException:
            pass # no delta table at the sink yet, so every source row counts as new

        source_df = self.load(source_path)
        if not watermark:
            return source_df
        return source_df.where(f"rundate > '{watermark}'")

    def refine_all(self, dataset_path, metadata=None, primary_key='id'):
        """ Refines every entity folder found under stage2/Ingested/<dataset_path> (see refine).
            Returns the total number of changed rows processed.
            metadata may be:
            - None, so each entity's metadata is read from the dataset's metadata.csv by refine();
            - a metadata dictionary keyed by entity name, as returned by get_metadata_from_url / get_metadata_from_path.
              Each entity gets its own entry; an entity missing from the dictionary falls back to metadata.csv.
            - anything else, which is passed unchanged to refine() for every entity.
            primary_key is passed to refine() for every entity.
            eg, refine_all('contoso_sis/v0.1') # refines all entities found in stage2/Ingested/contoso_sis/v0.1
        """
        folders = self.get_folders(self.to_url(f'stage2/Ingested/{dataset_path}'))
        number_of_new_inbound_rows = 0
        for entity_name in folders:
            entity_metadata = metadata.get(entity_name) if isinstance(metadata, dict) else metadata
            number_of_new_inbound_rows += self.refine(f'{dataset_path}/{entity_name}', entity_metadata, primary_key)
        return number_of_new_inbound_rows

    def refine(self, source_path, metadata=None, primary_key='id'):
        """ Refines the data for an entity in the Ingested folder and writes it to the Refined folder.
            This method performs the pseudonymization of the ingested folder and writes to General and Sensitive folders under Refined.
        """
        source_path = f'stage2/Ingested/{source_path}'
        primary_key = self.fix_column_name(primary_key) # fix the column name, in case it has a space in it or some other invalid character
        path_dict = self.parse_path(source_path)
        sink_general_path = path_dict['entity_parent_path'].replace('Ingested', 'Refined') + '/general/' + path_dict['entity']
        sink_sensitive_path = path_dict['entity_parent_path'].replace('Ingested', 'Refined') + '/sensitive/' + path_dict['entity'] + '_lookup'
        if not metadata:
            all_metadata = self.get_metadata_from_path(path_dict['entity_parent_path'])
            metadata = all_metadata[path_dict['entity']]

        df_changes = self.get_latest_changes(source_path, sink_general_path)

        if df_changes.count() > 0:
            df_pseudo, df_lookup = self.pseudonymize(df_changes, metadata)
            self.upsert(df_pseudo, sink_general_path, f'{primary_key}_pseudonym') # todo: remove this assumption that the primary key will always be hashed during pseduonymization
            self.upsert(df_lookup, sink_sensitive_path, primary_key)    
            self.add_to_lake_db(sink_general_path)
            self.add_to_lake_db(sink_sensitive_path)
            logger.info(f'Processed {df_changes.count()} updated rows from {source_path} into stage2/Refined')
        else:
            logger.info(f'No updated rows in {source_path} to process.')
        return df_changes.count()

    def load_csv(self, source_path, header=True):
        """ Loads a csv file as a dataframe based on the path specified """
        options = {'format':'csv', 'header':header}
        df = spark.read.load(self.to_url(source_path), **options)
        return df      

    def load_json(self, source_path, multiline=False):
        """ Loads a json file as a dataframe based on the path specified """
        options = {'format':'json', 'multiline':multiline}
        df = spark.read.load(self.to_url(source_path), **options)
        return df    

    def pseudonymize(self, df, metadata): #: list[list[str]]):
        """ Splits df into a pseudonymized dataframe and a lookup dataframe, following the OEA metadata rules.
            Each metadata row is read positionally: row[0] is the column name and row[2] is the operation.
            Supported operations:
              'hash-no-lookup' / 'hnl' - in df_pseudo the column becomes sha2-256(value + salt) and is renamed to <col>_pseudonym;
                                         the column is dropped from df_lookup (the lookup can be done against a different table).
              'hash' / 'h'             - same change in df_pseudo; df_lookup keeps the original column and gains <col>_pseudonym.
              'mask' / 'm'             - the column's values are replaced with '*' in df_pseudo; df_lookup keeps the original values.
              'partition-by'           - the column is left unchanged in both dataframes (so both can be partitioned by it).
              'no-op' / 'x'            - the column is left unchanged in df_pseudo and dropped from df_lookup.
            Rows with any other operation are ignored.
            For example, for an entity called person, df_pseudo would be written as person (hashed ids, masked fields) and
            df_lookup as person_lookup; the lookup table should be written to a "sensitive" folder in the data lake.
            eg, df_pseudo, df_lookup = oea.pseudonymize(df, metadata)
            [More info on this approach here: https://learn.microsoft.com/en-us/azure/databricks/security/privacy/gdpr-delta#pseudonymize-data]
        """
        salt = self._get_salt()

        def salted_sha256(column_name):
            # sha2-256 over the column value with the salt appended
            return F.sha2(F.concat(F.col(column_name), F.lit(salt)), 256)

        pseudo, lookup = df, df
        for entry in metadata:
            column_name, operation = entry[0], entry[2]
            if operation in ("hash-no-lookup", "hnl"):
                # the lookup can be performed against a different table, so the column isn't kept in the lookup df
                pseudo = pseudo.withColumn(column_name, salted_sha256(column_name)).withColumnRenamed(column_name, column_name + "_pseudonym")
                lookup = lookup.drop(column_name)
            elif operation in ("hash", 'h'):
                pseudo = pseudo.withColumn(column_name, salted_sha256(column_name)).withColumnRenamed(column_name, column_name + "_pseudonym")
                lookup = lookup.withColumn(column_name + "_pseudonym", salted_sha256(column_name))
            elif operation in ("mask", 'm'):
                pseudo = pseudo.withColumn(column_name, F.lit('*'))
            elif operation == "partition-by":
                pass # left untouched in both dataframes so it can be used for partitioning
            elif operation in ("no-op", 'x'):
                lookup = lookup.drop(column_name)
        return (pseudo, lookup)

    def add_to_lake_db(self, source_entity_path):
        """ Adds the given entity as a table (if the table doesn't already exist) to the proper lake db based on the path.
            This method will also create the lake db if it doesn't already exist.
            eg: add_to_lake_db('stage2/Ingested/contoso_sis/v0.1/students')

            Note that a spark db that points to source data in the delta format can't be queried via SQL serverless pool. More info here: https://docs.microsoft.com/en-us/azure/synapse-analytics/sql/resources-self-help-sql-on-demand#delta-lake
        """
        source_dict = self.parse_path(source_entity_path)
        db_name = source_dict['ldb_name']
        spark.sql(f'CREATE DATABASE IF NOT EXISTS {db_name}')
        spark.sql(f"create table if not exists {db_name}.{source_dict['entity']} using DELTA location '{self.to_url(source_dict['entity_path'])}'")

    def drop_lake_db(self, db_name):
        """ Deletes the lake db by the given name """
        spark.sql(f'DROP DATABASE IF EXISTS {db_name} CASCADE')
        result = "Database dropped: " + db_name
        logger.info(result)
        return result

    def create_sql_db(self, source_path):
        """ Prints out the sql script needed for creating a sql serverless db and set of views. """
        source_dict = self.parse_path(source_path)
        db_name = source_dict['sdb_name']
        cmd = '-- Create a new sql script then execute the following in it:\n'
        cmd += f"IF NOT EXISTS (SELECT * FROM sys.databases WHERE name = '{db_name}')\nBEGIN\n  CREATE DATABASE {db_name};\nEND;\nGO\n"
        cmd += f"USE {db_name};\nGO\n\n"
        cmd += self.create_sql_views(source_dict['entity_parent_path'])
        print(cmd)

    def create_sql_views(self, source_path):
        """ Returns the SQL script required to create views for the entities under the given path """
        cmd = ''      
        dirs = self.get_folders(source_path)
        for table_name in dirs:
            cmd += f"CREATE OR ALTER VIEW {table_name} AS\n  SELECT * FROM OPENROWSET(BULK '{self.to_url(source_path)}/{table_name}', FORMAT='delta') AS [r];\nGO\n"
        return cmd 

    def drop_sql_db(self, db_name):
        """ Prints the SQL script required to delete the db by the given name """
        cmd = '-- Create a new sql script then execute the following in it. Alternatively, you can click on the menu next to the SQL db and select "Delete"\n'
        cmd += '-- [Note that this does not affect the data in the data lake - this will only delete the sql db that points to that data.]\n\n'
        cmd += f'DROP DATABASE {db_name}'
        print(cmd)       

class DataLakeWriter: 
    """ Utility class that appends text data to files under a root ADLS location (via mssparkutils).
    """
    def __init__(self, root_destination):
        # Base location; every path passed to write() is joined onto it with a '/'.
        self.root_destination = root_destination

    def write(self, path_and_filename, data_str, format='csv'):
        """ Appends data_str to <root_destination>/<path_and_filename>, creating the file if it doesn't exist.
            The format param is currently unused.
        """
        target_file = f"{self.root_destination}/{path_and_filename}"
        create_if_missing = True # last argument of fs.append: create the file if it does not exist
        mssparkutils.fs.append(target_file, data_str, create_if_missing)

# Module-level OEA instance built with the constructor's default arguments (workspace='dev').
# NOTE(review): presumably notebooks that run this file use this shared `oea` object directly -- confirm.
oea = OEA()